1414 --nbatches: Total number of batches to process.
1515 --warmup_batches: Number of batches to skip as warmup.
1616 --audio: Path to the input audio file for ASR.
17- --audio_maxlen: Maximum duration of audio to process (in seconds).
18- --precision: Model precision (16, 32, or bf16).
19- --cudnn_benchmark: Enable cuDNN benchmarking.
20- --log: Enable logging.
2117
2218Example:
23- python calculate_rtf .py --model stt_en_conformer_ctc_large --decoding_type ctc --gpu 0 --batch_size 1 --nbatches 5 --warmup_batches 5 --audio /path/to/audio.wav --audio_maxlen 600 --precision bf16 --cudnn_benchmark
19+ python calc_rtf .py --model stt_en_conformer_ctc_large --decoding_type ctc
2420"""
25-
21+ import copy
22+ from omegaconf import OmegaConf
2623import time
2724import argparse
2825from tqdm import tqdm
2926import torch
30- from omegaconf import OmegaConf
31- import copy
3227import sys
3328import soundfile as sf
3429import numpy as np
30+ import librosa
3531
3632from nemo .utils import logging
37- from contextlib import nullcontext
3833from nemo .collections .asr .models import ASRModel
3934
4035
4136
4237parser = argparse .ArgumentParser (description = 'model forward pass profiler / performance tester.' )
43- parser .add_argument ("--model" , default = 'stt_en_conformer_ctc_large ' , type = str , help = "ASR model" )
44- parser .add_argument ("--decoding_type" , default = 'ctc' , type = str , help = "Encoding type (bpe or char) " )
38+ parser .add_argument ("--model" , default = 'stt_en_fastconformer_ctc_large ' , type = str , help = "ASR model" )
39+ parser .add_argument ("--decoding_type" , default = 'ctc' , type = str , help = "Type of model [rnnt, ctc, aed] " )
4540parser .add_argument ("--gpu" , default = 0 , type = int , help = "GPU device to use" )
4641parser .add_argument ("--batch_size" , default = 1 , type = int , help = "batch size to use" )
47- parser .add_argument ("--nbatches" , default = 5 , type = int , help = "Total Number of batches to process" )
48- parser .add_argument ("--warmup_batches" , default = 5 , type = int , help = "Number of batches to skip as warmup" )
49- parser .add_argument ("--audio" , default = "/disk3/datasets/speech-datasets/earnings22/media/4469669.wav" , type = str , help = "wav file to use" )
50- parser .add_argument ("--audio_maxlen" , default = 16 , type = float , help = "cut the file at given length if it is longer" )
51- parser .add_argument ("--precision" , default = 'bf16' , type = str , help = "precision: 16/32/bf16" )
52- parser .add_argument ("--cudnn_benchmark" , dest = "enable_cudnn_bench" , action = "store_true" , help = "toggle cudnn benchmarking" , default = True )
53- parser .add_argument ("--log" , dest = "log" , action = "store_true" , help = "toggle logging" , default = True )
42+ parser .add_argument ("--nbatches" , default = 3 , type = int , help = "Total Number of batches to process" )
43+ parser .add_argument ("--warmup_batches" , default = 3 , type = int , help = "Number of batches to skip as warmup" )
44+ parser .add_argument ("--audio" , default = "../data/sample_4469669.wav" , type = str , help = "wav file to use" )
5445
55- args = parser .parse_args ()
56-
57- if args .log :
58- # INFO
59- logging .setLevel (20 )
60- else :
61- logging .setLevel (0 )
46+ # parser.add_argument("--audio_maxlen", default=30, type=float, help="Multiple chunks of audio of this length is used to calculate RTFX")
6247
63- if args . enable_cudnn_bench :
64- torch .backends .cudnn .benchmark = True
48+ args = parser . parse_args ()
49+ torch .backends .cudnn .benchmark = True
6550
66- PRECISION = args .precision
6751WAV = args .audio
68- audio_maxlen = args .audio_maxlen
52+ SAMPLING_RATE = 16000
53+ chunk_len = 30
54+ total_audio_len = 600
6955MODEL = args .model
7056batch_size = args .batch_size
7157nbatches = args .nbatches
7460
7561DEVICE = torch .device (args .gpu )
7662
77- if PRECISION != 'bf16' and PRECISION != '16' and PRECISION != '32' :
78- logging .error (f'unknown precision: { PRECISION } ' )
79- sys .exit (1 )
80-
81- logging .info (f'precision: { PRECISION } ' )
82- logging .info (f'WAV: { WAV } ' )
83- logging .info (f'AUDIO MAXLEN: { audio_maxlen } ' )
8463logging .info (f'MODEL: { MODEL } ' )
85- logging .info (f'batch_size: { batch_size } ' )
86- logging .info (f'num batches: { nbatches } ' )
87- logging .info (f'cudnn_benchmark: { args .enable_cudnn_bench } ' )
88-
8964
90- def get_samples (audio_file , audio_maxlen , target_sr = 16000 ):
65+ def get_samples (audio_file , total_audio_len , target_sr = 16000 ):
9166 with sf .SoundFile (audio_file , 'r' ) as f :
9267 dtype = 'int16'
9368 sample_rate = f .samplerate
9469 samples = f .read (dtype = dtype )
9570 if sample_rate != target_sr :
96- samples = librosa .core .resample (samples , sample_rate , target_sr )
71+ samples = librosa .core .resample (samples , orig_sr = sample_rate , target_sr = target_sr )
9772 samples = samples .astype ('float32' ) / 32768
9873 samples = samples .transpose ()
9974 sample_length = samples .shape [0 ]
100- if sample_length > audio_maxlen * target_sr :
101- logging .info (f'resizing audio sample from { sample_length / target_sr } to maxlen of { audio_maxlen } ' )
102- sample_length = int (audio_maxlen * target_sr )
75+ if sample_length > total_audio_len * target_sr :
76+ logging .info (f'resizing audio sample from { sample_length / target_sr } to maxlen of { total_audio_len } ' )
77+ sample_length = int (total_audio_len * target_sr )
10378 samples = samples [:sample_length ]
10479 logging .info (f'new sample lengh: { samples .shape [0 ]} ' )
10580 else :
106- pad_length = int (audio_maxlen * target_sr ) - sample_length
107- logging .info (f'padding audio sample from { sample_length / target_sr } to maxlen of { audio_maxlen } ' )
81+ pad_length = int (total_audio_len * target_sr ) - sample_length
82+ logging .info (f'padding audio sample from { sample_length / target_sr } to maxlen of { total_audio_len } ' )
10883 samples = np .pad (samples , (0 , pad_length ), 'constant' , constant_values = (0 , 0 ))
109- sample_length = int (audio_maxlen * target_sr )
84+ sample_length = int (total_audio_len * target_sr )
11085
11186 return samples , sample_length
11287
113- def preprocess_audio (preprocessor , audio , device ):
114- audio_signal = torch .from_numpy (audio ).unsqueeze_ (0 ).to (device )
115-
116- audio_signal_len = torch .Tensor ([audio .shape [0 ]]).to (device )
117- processed_signal , processed_signal_length = preprocessor (
118- input_signal = audio_signal , length = audio_signal_len
119- )
120- return processed_signal , processed_signal_length
12188
12289def extract_preprocessor (model , device ):
12390 cfg = copy .deepcopy (model ._cfg )
@@ -135,64 +102,58 @@ def main():
135102 asr_model = ASRModel .from_pretrained (MODEL )
136103
137104 asr_model .to (DEVICE )
138- asr_model .encoder .eval ()
139- asr_model .encoder .freeze ()
105+ asr_model .eval ()
140106 asr_model ._prepare_for_export ()
141107
142- processor = extract_preprocessor (asr_model , DEVICE )
143-
144- input_example , input_example_length = get_samples (WAV , audio_maxlen )
145- logging .info (f'processed example shape: { input_example .shape } ' )
146- logging .info (f'processed example length shape: { input_example_length } ' )
147- processed_example , processed_example_length = preprocess_audio (processor , input_example , DEVICE )
148- processed_example = processed_example .repeat (batch_size , 1 , 1 )
149- processed_example_length = processed_example_length .repeat (batch_size )
150- logging .info (f'processed example shape: { processed_example .size ()} ' )
151- logging .info (f'processed example length shape: { processed_example_length .size ()} ' )
152-
153- profiling_context = nullcontext ()
154- # if FP16:
155- if PRECISION == '16' :
156- precision_context = torch .cuda .amp .autocast ()
157- elif PRECISION == 'bf16' :
158- precision_context = torch .cuda .amp .autocast (dtype = torch .bfloat16 )
159- elif PRECISION == '32' :
160- pass
161- else :
162- logging .error (f'unknown precision: { PRECISION } ' )
163- sys .exit (1 )
164-
108+ preprocessor = extract_preprocessor (asr_model , DEVICE )
109+ input_example , input_example_length = get_samples (WAV , total_audio_len )
110+ input_example = torch .tensor (input_example ).to (DEVICE )
111+ input_example = input_example .repeat (batch_size , 1 )
112+ input_example_length = torch .tensor (input_example_length ).to (DEVICE )
113+ input_example_length = input_example_length .repeat (batch_size )
165114
166- if decoding_type == 'ctc' :
167- asr_model . change_decoding_strategy ( decoding_cfg = None )
115+ processed_signal , processed_signal_length = preprocessor ( input_signal = input_example , length = input_example_length )
116+ processed_example = processed_signal . repeat ( batch_size , 1 , 1 )
168117
118+
169119 logging .info (f"running { nbatches } batches; with { warmup_batches } batches warmup; batch_size: { batch_size } " )
170120 rtfs = []
171121 for i in range (3 ): # average over 3 runs
172122 total_time = 0
173- with profiling_context :
174- with precision_context :
175- with torch .no_grad ():
176- for i in tqdm (range (nbatches + warmup_batches )):
177-
178- start = time .time ()
179- if decoding_type == 'rnnt' :
180- enc_out , enc_len = asr_model .encoder .forward (audio_signal = processed_example , length = processed_example_length )
181- dec_out , dec_len = asr_model .decoding .rnnt_decoder_predictions_tensor (
182- encoder_output = enc_out , encoded_lengths = enc_len , return_hypotheses = False
183- )
184- else :
185- enc_out , enc_len , greedy_predictions = asr_model .forward (processed_signal = processed_example , processed_signal_length = processed_example_length )
186- dec_out , dec_len = asr_model .decoding .ctc_decoder_predictions_tensor (
187- enc_out , decoder_lengths = enc_len , return_hypotheses = False
188- )
189- torch .cuda .synchronize ()
190- end = time .time ()
191- if i >= warmup_batches :
192- total_time += end - start
123+ with torch .cuda .amp .autocast (dtype = torch .bfloat16 ):
124+ with torch .no_grad ():
125+ for i in tqdm (range (nbatches + warmup_batches )):
126+ start = time .time ()
127+ if decoding_type == 'rnnt' :
128+ enc_out , enc_len = asr_model .encoder .forward (audio_signal = processed_example , length = processed_signal_length )
129+ dec_out , dec_len = asr_model .decoding .rnnt_decoder_predictions_tensor (
130+ encoder_output = enc_out , encoded_lengths = enc_len , return_hypotheses = False
131+ )
132+ elif decoding_type == 'ctc' :
133+ enc_out , enc_len , greedy_predictions = asr_model .forward (input_signal = input_example , input_signal_length = input_example_length )
134+ dec_out , dec_len = asr_model .decoding .ctc_decoder_predictions_tensor (
135+ enc_out , decoder_lengths = enc_len , return_hypotheses = False
136+ )
137+ elif decoding_type == 'aed' :
138+ log_probs , encoded_len , enc_states , enc_mask = asr_model .forward (input_signal = input_example , input_signal_length = input_example_length )
139+ beam_hypotheses = asr_model .decoding .decode_predictions_tensor (
140+ encoder_hidden_states = enc_states ,
141+ encoder_input_mask = enc_mask ,
142+ decoder_input_ids = None , #torch.tensor([[ 3, 4, 8, 4, 11]]).to(DEVICE),
143+ return_hypotheses = False ,
144+ )[0 ]
145+
146+ beam_hypotheses = [asr_model .decoding .strip_special_tokens (text ) for text in beam_hypotheses ]
147+ else :
148+ raise ValueError (f'Invalid decoding type: { decoding_type } ' )
149+
150+ torch .cuda .synchronize ()
151+ end = time .time ()
152+ if i >= warmup_batches :
153+ total_time += end - start
193154
194155
195- rtf = (total_time / nbatches ) / (input_example_length / 16000 )
156+ rtf = (total_time / nbatches ) / (float ( input_example_length ) / 16000 )
196157
197158 rtfs .append (rtf )
198159
0 commit comments