Skip to content

Commit 8c1f8dc

Browse files
authored
Merge pull request #23 from stevehuang52/add_nemo_chunk_infer
add nemo chunk rtf cal
2 parents a0b85c8 + cc33886 commit 8c1f8dc

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

nemo_asr/calc_rtf_chunk.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
"""
2+
NeMo ASR Model Profiler
3+
4+
This script performs a forward pass on an NeMo ASR models and measures its real-time factor (RTF).
5+
RTF is a metric used to evaluate the processing speed of ASR models.
6+
7+
# audio has to be a mono wav file with 16kHz sample rate
8+
9+
Parameters:
10+
--model: ASR model name or path to the model checkpoint file.
11+
--decoding_type: Type of decoding to use (ctc or rnnt).
12+
--gpu: GPU device to use.
13+
--batch_size: Batch size to use for inference.
14+
--nbatches: Total number of batches to process.
15+
--warmup_batches: Number of batches to skip as warmup.
16+
--audio: Path to the input audio file for ASR.
17+
18+
Example:
19+
python calc_rtf.py --model stt_en_conformer_ctc_large --decoding_type ctc
20+
"""
21+
import math
22+
import copy
23+
from omegaconf import OmegaConf
24+
import time
25+
import argparse
26+
from tqdm import tqdm
27+
import torch
28+
import sys
29+
import soundfile as sf
30+
import numpy as np
31+
import librosa
32+
33+
from nemo.utils import logging
34+
from nemo.collections.asr.models import ASRModel
35+
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecodingConfig
36+
from nemo.collections.asr.parts.utils.streaming_utils import AudioFeatureIterator, FrameBatchChunkedCTC, FrameBatchChunkedRNNT, FrameBatchMultiTaskAED
37+
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
38+
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
39+
40+
parser = argparse.ArgumentParser(description='model forward pass profiler / performance tester.')
41+
parser.add_argument("--model", default='nvidia/canary-1b', type=str, help="ASR model")
42+
parser.add_argument("--decoding_type", default='aed', type=str, help="Type of model [rnnt, ctc, aed]")
43+
parser.add_argument("--gpu", default=0, type=int, help="GPU device to use")
44+
parser.add_argument("--batch_size", default=1, type=int, help="batch size to use")
45+
parser.add_argument("--nbatches", default=3, type=int, help="Total Number of batches to process")
46+
parser.add_argument("--warmup_batches", default=3, type=int, help="Number of batches to skip as warmup")
47+
parser.add_argument("--audio", default="../data/sample_4469669.wav", type=str, help="wav file to use")
48+
# parser.add_argument("--audio_maxlen", default=30, type=float, help="Multiple chunks of audio of this length is used to calculate RTFX")
49+
50+
51+
args = parser.parse_args()
52+
torch.backends.cudnn.benchmark=True
53+
54+
WAV = args.audio
55+
SAMPLING_RATE = 16000
56+
chunk_len = 30
57+
total_audio_len = 600
58+
MODEL = args.model
59+
batch_size = args.batch_size
60+
nbatches = args.nbatches
61+
warmup_batches = args.warmup_batches
62+
decoding_type = args.decoding_type
63+
model_stride = 8 # 8 for fastconformer and citrinet, 4 for conformer
64+
chunk_batch_size = 24 # number of chunks to run in parallel
65+
DEVICE=torch.device(args.gpu)
66+
logging.info(f'MODEL: {MODEL}')
67+
68+
def get_samples(audio_file, total_audio_len, target_sr=16000):
69+
with sf.SoundFile(audio_file, 'r') as f:
70+
dtype = 'int16'
71+
sample_rate = f.samplerate
72+
samples = f.read(dtype=dtype)
73+
if sample_rate != target_sr:
74+
samples = librosa.core.resample(samples, orig_sr=sample_rate, target_sr=target_sr)
75+
samples = samples.astype('float32') / 32768
76+
samples = samples.transpose()
77+
sample_length = samples.shape[0]
78+
if sample_length > total_audio_len * target_sr:
79+
logging.info(f'resizing audio sample from {sample_length / target_sr} to maxlen of {total_audio_len}')
80+
sample_length = int(total_audio_len * target_sr)
81+
samples = samples[:sample_length]
82+
logging.info(f'new sample lengh: {samples.shape[0]}')
83+
else:
84+
pad_length = int(total_audio_len * target_sr) - sample_length
85+
logging.info(f'padding audio sample from {sample_length / target_sr} to maxlen of {total_audio_len}')
86+
samples = np.pad(samples, (0, pad_length), 'constant', constant_values=(0, 0))
87+
sample_length = int(total_audio_len * target_sr)
88+
89+
return samples, sample_length
90+
91+
92+
def extract_preprocessor(model, device):
93+
cfg = copy.deepcopy(model._cfg)
94+
OmegaConf.set_struct(cfg.preprocessor, False)
95+
cfg.preprocessor.dither = 0.0
96+
cfg.preprocessor.pad_to = 0
97+
preprocessor = model.from_config_dict(cfg.preprocessor)
98+
return preprocessor.to(device)
99+
100+
def setup_aed_decoding(asr_model):
101+
decoding_cfg = MultiTaskDecodingConfig()
102+
decoding_cfg.strategy = "beam"
103+
decoding_cfg.beam.beam_size = 1
104+
asr_model.change_decoding_strategy(decoding_cfg)
105+
106+
def setup_rnnt_decoding(asr_model):
107+
decoding_cfg = RNNTDecodingConfig()
108+
decoding_cfg.strategy = "greedy_batch"
109+
if hasattr(asr_model, 'cur_decoder'):
110+
asr_model.change_decoding_strategy(decoding_cfg, decoder_type="rnnt")
111+
else:
112+
asr_model.change_decoding_strategy(decoding_cfg)
113+
114+
def setup_ctc_decoding(asr_model):
115+
decoding_cfg = CTCDecodingConfig()
116+
decoding_cfg.strategy = "greedy"
117+
if hasattr(asr_model, 'cur_decoder'):
118+
asr_model.change_decoding_strategy(decoding_cfg, decoder_type="ctc")
119+
else:
120+
asr_model.change_decoding_strategy(decoding_cfg)
121+
122+
def setup_rnnt_chunk_infer(frame_asr, audio_input):
123+
frame_reader = AudioFeatureIterator(audio_input, frame_asr.frame_len, frame_asr.raw_preprocessor, frame_asr.asr_model.device)
124+
frame_asr.set_frame_reader(frame_reader)
125+
126+
def setup_aed_chunk_infer(frame_asr, audio_input, meta_data):
127+
frame_asr.input_tokens = frame_asr.get_input_tokens(meta_data)
128+
frame_reader = AudioFeatureIterator(audio_input, frame_asr.frame_len, frame_asr.raw_preprocessor, frame_asr.asr_model.device)
129+
frame_asr.set_frame_reader(frame_reader)
130+
131+
def setup_ctc_chunk_infer(frame_asr, audio_input):
132+
frame_reader = AudioFeatureIterator(audio_input, frame_asr.frame_len, frame_asr.raw_preprocessor, frame_asr.asr_model.device)
133+
frame_asr.set_frame_reader(frame_reader)
134+
135+
136+
def main():
137+
if MODEL.endswith('.nemo'):
138+
asr_model = ASRModel.restore_from(MODEL)
139+
else:
140+
asr_model = ASRModel.from_pretrained(MODEL)
141+
142+
asr_model.to(DEVICE)
143+
asr_model.eval()
144+
asr_model._prepare_for_export()
145+
146+
input_example, input_example_length = get_samples(WAV, total_audio_len)
147+
148+
frame_asr = None
149+
if decoding_type == 'aed':
150+
setup_aed_decoding(asr_model)
151+
frame_asr = FrameBatchMultiTaskAED(
152+
asr_model=asr_model,
153+
frame_len=chunk_len,
154+
total_buffer=chunk_len,
155+
batch_size=chunk_batch_size,
156+
)
157+
elif decoding_type == 'rnnt':
158+
setup_rnnt_decoding(asr_model)
159+
frame_asr = FrameBatchChunkedRNNT(
160+
asr_model=asr_model,
161+
frame_len=chunk_len,
162+
total_buffer=chunk_len,
163+
batch_size=chunk_batch_size,
164+
)
165+
elif decoding_type == 'ctc':
166+
setup_ctc_decoding(asr_model)
167+
frame_asr = FrameBatchChunkedCTC(
168+
asr_model=asr_model,
169+
frame_len=chunk_len,
170+
total_buffer=chunk_len,
171+
batch_size=chunk_batch_size,
172+
)
173+
else:
174+
raise ValueError(f'Invalid decoding type: {decoding_type}, must be one of [ctc, rnnt, aed]')
175+
176+
177+
logging.info(f"running {nbatches} batches; with {warmup_batches} batches warmup; batch_size: {batch_size}")
178+
rtfs=[]
179+
for i in range(3): # average over 3 runs
180+
total_time = 0
181+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
182+
with torch.no_grad():
183+
for i in tqdm(range(nbatches + warmup_batches)):
184+
frame_asr.reset()
185+
start = time.time()
186+
if decoding_type == 'ctc':
187+
setup_ctc_chunk_infer(frame_asr, input_example)
188+
elif decoding_type == 'rnnt':
189+
setup_rnnt_chunk_infer(frame_asr, input_example)
190+
elif decoding_type == 'aed':
191+
meta = {
192+
'audio_filepath': WAV,
193+
'duration': total_audio_len,
194+
'source_lang': 'en',
195+
'taskname': 'asr',
196+
'target_lang': 'en',
197+
'pnc': 'yes',
198+
'answer': 'nvidia',
199+
}
200+
setup_aed_chunk_infer(frame_asr, input_example, meta)
201+
else:
202+
raise ValueError(f'Invalid decoding type: {decoding_type}')
203+
hyp = frame_asr.transcribe()
204+
torch.cuda.synchronize()
205+
end = time.time()
206+
if i >= warmup_batches:
207+
total_time += end - start
208+
209+
rtf = (total_time/nbatches) / (float(input_example_length) / 16000)
210+
211+
rtfs.append(rtf)
212+
213+
print(f'RTF: {rtfs}')
214+
rtf = sum(rtfs)/len(rtfs)
215+
sys.stdout.write(f'{rtf:.4f}\n')
216+
217+
if __name__ == '__main__':
218+
main()

0 commit comments

Comments
 (0)