
Commit f1956ce

Merge pull request #7 from huggingface/add_remaining_nemo_models
add rtf and remaining nemo models
2 parents 68b9be9 + 5b0d7bc commit f1956ce

File tree

4 files changed, +353 -20 lines changed


nemo_asr/calculate_rtf.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
"""
NeMo ASR Model Profiler

This script performs a forward pass on a NeMo ASR model and measures its real-time factor (RTF).
RTF is a metric used to evaluate the processing speed of ASR models.

# audio has to be a mono wav file with a 16 kHz sample rate

Parameters:
--model: ASR model name or path to the model checkpoint file.
--decoding_type: Type of decoding to use (ctc or rnnt).
--gpu: GPU device to use.
--batch_size: Batch size to use for inference.
--nbatches: Total number of batches to process.
--warmup_batches: Number of batches to skip as warmup.
--audio: Path to the input audio file for ASR.
--audio_maxlen: Maximum duration of audio to process (in seconds).
--precision: Model precision (16, 32, or bf16).
--cudnn_benchmark: Enable cuDNN benchmarking.
--log: Enable logging.

Example:
python calculate_rtf.py --model stt_en_conformer_ctc_large --decoding_type ctc --gpu 0 --batch_size 1 --nbatches 5 --warmup_batches 5 --audio /path/to/audio.wav --audio_maxlen 600 --precision bf16 --cudnn_benchmark
"""

import time
import argparse
from tqdm import tqdm
import torch
from omegaconf import OmegaConf
import copy
import sys
import soundfile as sf
import numpy as np
import librosa  # used to resample inputs that are not already at the target sample rate

from nemo.utils import logging
from contextlib import nullcontext
from nemo.collections.asr.models import ASRModel


parser = argparse.ArgumentParser(description='model forward pass profiler / performance tester.')
parser.add_argument("--model", default='stt_en_conformer_ctc_large', type=str, help="ASR model")
parser.add_argument("--decoding_type", default='ctc', type=str, help="Decoding type (ctc or rnnt)")
parser.add_argument("--gpu", default=0, type=int, help="GPU device to use")
parser.add_argument("--batch_size", default=1, type=int, help="batch size to use")
parser.add_argument("--nbatches", default=5, type=int, help="Total number of batches to process")
parser.add_argument("--warmup_batches", default=5, type=int, help="Number of batches to skip as warmup")
parser.add_argument("--audio", default="/disk3/datasets/speech-datasets/earnings22/media/4469669.wav", type=str, help="wav file to use")
parser.add_argument("--audio_maxlen", default=16, type=float, help="cut the file at given length if it is longer")
parser.add_argument("--precision", default='bf16', type=str, help="precision: 16/32/bf16")
parser.add_argument("--cudnn_benchmark", dest="enable_cudnn_bench", action="store_true", help="toggle cudnn benchmarking", default=True)
parser.add_argument("--log", dest="log", action="store_true", help="toggle logging", default=True)

args = parser.parse_args()

if args.log:
    # INFO
    logging.setLevel(20)
else:
    logging.setLevel(0)

if args.enable_cudnn_bench:
    torch.backends.cudnn.benchmark = True

PRECISION = args.precision
WAV = args.audio
audio_maxlen = args.audio_maxlen
MODEL = args.model
batch_size = args.batch_size
nbatches = args.nbatches
warmup_batches = args.warmup_batches
decoding_type = args.decoding_type

DEVICE = torch.device(args.gpu)

if PRECISION != 'bf16' and PRECISION != '16' and PRECISION != '32':
    logging.error(f'unknown precision: {PRECISION}')
    sys.exit(1)

logging.info(f'precision: {PRECISION}')
logging.info(f'WAV: {WAV}')
logging.info(f'AUDIO MAXLEN: {audio_maxlen}')
logging.info(f'MODEL: {MODEL}')
logging.info(f'batch_size: {batch_size}')
logging.info(f'num batches: {nbatches}')
logging.info(f'cudnn_benchmark: {args.enable_cudnn_bench}')


def get_samples(audio_file, audio_maxlen, target_sr=16000):
    with sf.SoundFile(audio_file, 'r') as f:
        dtype = 'int16'
        sample_rate = f.samplerate
        samples = f.read(dtype=dtype)
        if sample_rate != target_sr:
            samples = librosa.core.resample(samples, sample_rate, target_sr)
        samples = samples.astype('float32') / 32768
        samples = samples.transpose()
        sample_length = samples.shape[0]
        if sample_length > audio_maxlen * target_sr:
            logging.info(f'resizing audio sample from {sample_length / target_sr} to maxlen of {audio_maxlen}')
            sample_length = int(audio_maxlen * target_sr)
            samples = samples[:sample_length]
            logging.info(f'new sample length: {samples.shape[0]}')
        else:
            pad_length = int(audio_maxlen * target_sr) - sample_length
            logging.info(f'padding audio sample from {sample_length / target_sr} to maxlen of {audio_maxlen}')
            samples = np.pad(samples, (0, pad_length), 'constant', constant_values=(0, 0))
            sample_length = int(audio_maxlen * target_sr)

        return samples, sample_length

def preprocess_audio(preprocessor, audio, device):
    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)

    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)
    processed_signal, processed_signal_length = preprocessor(
        input_signal=audio_signal, length=audio_signal_len
    )
    return processed_signal, processed_signal_length

def extract_preprocessor(model, device):
    cfg = copy.deepcopy(model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0
    preprocessor = model.from_config_dict(cfg.preprocessor)
    return preprocessor.to(device)

def main():

    if MODEL.endswith('.nemo'):
        asr_model = ASRModel.restore_from(MODEL)
    else:
        asr_model = ASRModel.from_pretrained(MODEL)

    asr_model.to(DEVICE)
    asr_model.encoder.eval()
    asr_model.encoder.freeze()
    asr_model._prepare_for_export()

    processor = extract_preprocessor(asr_model, DEVICE)

    input_example, input_example_length = get_samples(WAV, audio_maxlen)
    logging.info(f'input example shape: {input_example.shape}')
    logging.info(f'input example length: {input_example_length}')
    processed_example, processed_example_length = preprocess_audio(processor, input_example, DEVICE)
    processed_example = processed_example.repeat(batch_size, 1, 1)
    processed_example_length = processed_example_length.repeat(batch_size)
    logging.info(f'processed example shape: {processed_example.size()}')
    logging.info(f'processed example length shape: {processed_example_length.size()}')

    profiling_context = nullcontext()
    # if FP16:
    if PRECISION == '16':
        precision_context = torch.cuda.amp.autocast()
    elif PRECISION == 'bf16':
        precision_context = torch.cuda.amp.autocast(dtype=torch.bfloat16)
    elif PRECISION == '32':
        # full precision: run without autocast
        precision_context = nullcontext()
    else:
        logging.error(f'unknown precision: {PRECISION}')
        sys.exit(1)


    if decoding_type == 'ctc':
        asr_model.change_decoding_strategy(decoding_cfg=None)

    logging.info(f"running {nbatches} batches; with {warmup_batches} batches warmup; batch_size: {batch_size}")
    rtfs = []
    for run in range(3):  # average over 3 runs
        total_time = 0
        with profiling_context:
            with precision_context:
                with torch.no_grad():
                    for i in tqdm(range(nbatches + warmup_batches)):

                        start = time.time()
                        if decoding_type == 'rnnt':
                            enc_out, enc_len = asr_model.encoder.forward(audio_signal=processed_example, length=processed_example_length)
                            dec_out, dec_len = asr_model.decoding.rnnt_decoder_predictions_tensor(
                                encoder_output=enc_out, encoded_lengths=enc_len, return_hypotheses=False
                            )
                        else:
                            enc_out, enc_len, greedy_predictions = asr_model.forward(processed_signal=processed_example, processed_signal_length=processed_example_length)
                            dec_out, dec_len = asr_model.decoding.ctc_decoder_predictions_tensor(
                                enc_out, decoder_lengths=enc_len, return_hypotheses=False
                            )
                        torch.cuda.synchronize()
                        end = time.time()
                        if i >= warmup_batches:
                            total_time += end - start

        rtf = (total_time / nbatches) / (input_example_length / 16000)

        rtfs.append(rtf)

    print(f'RTF: {rtfs}')
    rtf = sum(rtfs) / len(rtfs)
    sys.stdout.write(f'{rtf:.4f}\n')

if __name__ == '__main__':
    main()
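
For quick reference, the RTF the script prints is simply the average per-batch processing time divided by the duration of the (padded or truncated) input clip. A minimal sketch of that arithmetic, using made-up numbers rather than measured ones (audio_seconds and avg_batch_seconds are illustrative names, not part of the script):

# Illustration only: the RTF formula used above, with hypothetical timings.
audio_seconds = 600.0       # duration of the input clip after padding/truncation (--audio_maxlen)
avg_batch_seconds = 3.0     # hypothetical average forward + decode wall-clock time per batch
rtf = avg_batch_seconds / audio_seconds
print(f'{rtf:.4f}')         # 0.0050, i.e. roughly 200x faster than real time

An RTF below 1.0 means the model processes audio faster than real time.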

nemo_asr/run_fast_conformer_ctc.sh

Lines changed: 74 additions & 10 deletions
@@ -2,34 +2,98 @@
 
 export PYTHONPATH="..":$PYTHONPATH
 
-#considering FC-XL, FC-XXL CTC models
-MODEL_IDs=("nvidia/stt_en_fastconformer_ctc_xlarge" "nvidia/stt_en_fastconformer_ctc_xxlarge")
+#considering FC-XL, FC-XXL, FC-L, C-L, C-S CTC models
+MODEL_IDs=("nvidia/stt_en_fastconformer_ctc_xxlarge" "nvidia/stt_en_fastconformer_ctc_xlarge" "nvidia/stt_en_fastconformer_ctc_large" "nvidia/stt_en_conformer_ctc_large" "nvidia/stt_en_conformer_ctc_small")
 BATCH_SIZE=8
+DEVICE_ID=0
 
 num_models=${#MODEL_IDs[@]}
 
 for (( i=0; i<${num_models}; i++ ));
 do
     MODEL_ID=${MODEL_IDs[$i]}
 
+
     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="librispeech_asr" \
-        --dataset="other" \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="ami" \
         --split="test" \
-        --device=0 \
+        --device=${DEVICE_ID} \
         --batch_size=${BATCH_SIZE} \
-        --max_eval_samples=-1
+        --max_eval_samples=-1
+
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="earnings22" \
+        --split="test" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
+
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="gigaspeech" \
+        --split="test" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
+
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="librispeech" \
+        --split="test.clean" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
 
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="librispeech" \
+        --split="test.other" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
+
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="spgispeech" \
+        --split="test" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
+
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="tedlium" \
+        --split="test" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
+
+    python run_eval.py \
+        --model_id=${MODEL_ID} \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="voxpopuli" \
+        --split="test" \
+        --device=${DEVICE_ID} \
+        --batch_size=${BATCH_SIZE} \
+        --max_eval_samples=-1
 
     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="librispeech_asr" \
-        --dataset="clean" \
+        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset="common_voice" \
         --split="test" \
-        --device=0 \
+        --device=${DEVICE_ID} \
         --batch_size=${BATCH_SIZE} \
-        --max_eval_samples=-1
+        --max_eval_samples=-1
 
 # Evaluate results
 RUNDIR=`pwd` && \
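
As a rough companion to the evaluation loop above, a hypothetical driver (not part of this commit) could collect RTF numbers for the same model list by shelling out to nemo_asr/calculate_rtf.py; the audio path below is a placeholder, and the flags simply mirror the script's argparse options:

# Hypothetical RTF sweep over the models evaluated in run_fast_conformer_ctc.sh.
import subprocess

model_ids = [
    "nvidia/stt_en_fastconformer_ctc_xxlarge",
    "nvidia/stt_en_fastconformer_ctc_xlarge",
    "nvidia/stt_en_fastconformer_ctc_large",
    "nvidia/stt_en_conformer_ctc_large",
    "nvidia/stt_en_conformer_ctc_small",
]
for model_id in model_ids:
    # /path/to/audio.wav is a placeholder; the script expects a mono 16 kHz wav file.
    subprocess.run(
        [
            "python", "calculate_rtf.py",
            "--model", model_id,
            "--decoding_type", "ctc",
            "--gpu", "0",
            "--batch_size", "1",
            "--nbatches", "5",
            "--warmup_batches", "5",
            "--audio", "/path/to/audio.wav",
            "--audio_maxlen", "600",
            "--precision", "bf16",
        ],
        check=True,
    )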
