Commit cc912b4

Add nemo aed models (#21)
* add wer aed

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* support of nemo aed models

Signed-off-by: Nithin Rao Koluguri <nithinraok>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
1 parent 8609c30 commit cc912b4
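
This commit extends the repo's NeMo scripts to handle attention encoder-decoder (AED) models in addition to CTC and RNN-T. All model families load through the same NeMo entry point. A minimal usage sketch (not part of the diff; the model id below is just the script's new default, and any NeMo model name, including AED ones, loads the same way):

from nemo.collections.asr.models import ASRModel

# Resolve a pretrained model by name and transcribe one file.
asr_model = ASRModel.from_pretrained("stt_en_fastconformer_ctc_large")
asr_model.eval()
transcripts = asr_model.transcribe(["data/sample_4469669.wav"], batch_size=1)
print(transcripts[0])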

File tree

5 files changed: +70 -106 lines changed

  data/sample_4469669.wav
  data/sample_ami-es2015b.wav
  nemo_asr/calc_rtf.py
  nemo_asr/run_eval.py
  requirements/requirements_nemo.txt

data/sample_4469669.wav (18.3 MB)
Binary file not shown.

data/sample_ami-es2015b.wav (18.3 MB)
Binary file not shown.

nemo_asr/calc_rtf.py

Lines changed: 65 additions & 104 deletions
@@ -14,58 +14,44 @@
 --nbatches: Total number of batches to process.
 --warmup_batches: Number of batches to skip as warmup.
 --audio: Path to the input audio file for ASR.
---audio_maxlen: Maximum duration of audio to process (in seconds).
---precision: Model precision (16, 32, or bf16).
---cudnn_benchmark: Enable cuDNN benchmarking.
---log: Enable logging.
 
 Example:
-    python calculate_rtf.py --model stt_en_conformer_ctc_large --decoding_type ctc --gpu 0 --batch_size 1 --nbatches 5 --warmup_batches 5 --audio /path/to/audio.wav --audio_maxlen 600 --precision bf16 --cudnn_benchmark
+    python calc_rtf.py --model stt_en_conformer_ctc_large --decoding_type ctc
 """
-
+import copy
+from omegaconf import OmegaConf
 import time
 import argparse
 from tqdm import tqdm
 import torch
-from omegaconf import OmegaConf
-import copy
 import sys
 import soundfile as sf
 import numpy as np
+import librosa
 
 from nemo.utils import logging
-from contextlib import nullcontext
 from nemo.collections.asr.models import ASRModel
 
 
 
 parser = argparse.ArgumentParser(description='model forward pass profiler / performance tester.')
-parser.add_argument("--model", default='stt_en_conformer_ctc_large', type=str, help="ASR model")
-parser.add_argument("--decoding_type", default='ctc', type=str, help="Encoding type (bpe or char)")
+parser.add_argument("--model", default='stt_en_fastconformer_ctc_large', type=str, help="ASR model")
+parser.add_argument("--decoding_type", default='ctc', type=str, help="Type of model [rnnt, ctc, aed]")
 parser.add_argument("--gpu", default=0, type=int, help="GPU device to use")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size to use")
-parser.add_argument("--nbatches", default=5, type=int, help="Total Number of batches to process")
-parser.add_argument("--warmup_batches", default=5, type=int, help="Number of batches to skip as warmup")
-parser.add_argument("--audio", default="/disk3/datasets/speech-datasets/earnings22/media/4469669.wav", type=str, help="wav file to use")
-parser.add_argument("--audio_maxlen", default=16, type=float, help="cut the file at given length if it is longer")
-parser.add_argument("--precision", default='bf16', type=str, help="precision: 16/32/bf16")
-parser.add_argument("--cudnn_benchmark", dest="enable_cudnn_bench", action="store_true", help="toggle cudnn benchmarking", default=True)
-parser.add_argument("--log", dest="log", action="store_true", help="toggle logging", default=True)
+parser.add_argument("--nbatches", default=3, type=int, help="Total Number of batches to process")
+parser.add_argument("--warmup_batches", default=3, type=int, help="Number of batches to skip as warmup")
+parser.add_argument("--audio", default="../data/sample_4469669.wav", type=str, help="wav file to use")
 
-args = parser.parse_args()
-
-if args.log:
-    # INFO
-    logging.setLevel(20)
-else:
-    logging.setLevel(0)
+# parser.add_argument("--audio_maxlen", default=30, type=float, help="Multiple chunks of audio of this length is used to calculate RTFX")
 
-if args.enable_cudnn_bench:
-    torch.backends.cudnn.benchmark=True
+args = parser.parse_args()
+torch.backends.cudnn.benchmark=True
 
-PRECISION = args.precision
 WAV = args.audio
-audio_maxlen = args.audio_maxlen
+SAMPLING_RATE = 16000
+chunk_len = 30
+total_audio_len = 600
 MODEL = args.model
 batch_size = args.batch_size
 nbatches = args.nbatches
@@ -74,50 +60,31 @@
 
 DEVICE=torch.device(args.gpu)
 
-if PRECISION != 'bf16' and PRECISION != '16' and PRECISION != '32':
-    logging.error(f'unknown precision: {PRECISION}')
-    sys.exit(1)
-
-logging.info(f'precision: {PRECISION}')
-logging.info(f'WAV: {WAV}')
-logging.info(f'AUDIO MAXLEN: {audio_maxlen}')
 logging.info(f'MODEL: {MODEL}')
-logging.info(f'batch_size: {batch_size}')
-logging.info(f'num batches: {nbatches}')
-logging.info(f'cudnn_benchmark: {args.enable_cudnn_bench}')
-
 
-def get_samples(audio_file, audio_maxlen, target_sr=16000):
+def get_samples(audio_file, total_audio_len, target_sr=16000):
     with sf.SoundFile(audio_file, 'r') as f:
         dtype = 'int16'
         sample_rate = f.samplerate
         samples = f.read(dtype=dtype)
         if sample_rate != target_sr:
-            samples = librosa.core.resample(samples, sample_rate, target_sr)
+            samples = librosa.core.resample(samples, orig_sr=sample_rate, target_sr=target_sr)
         samples = samples.astype('float32') / 32768
         samples = samples.transpose()
         sample_length = samples.shape[0]
-        if sample_length > audio_maxlen * target_sr:
-            logging.info(f'resizing audio sample from {sample_length / target_sr} to maxlen of {audio_maxlen}')
-            sample_length = int(audio_maxlen * target_sr)
+        if sample_length > total_audio_len * target_sr:
+            logging.info(f'resizing audio sample from {sample_length / target_sr} to maxlen of {total_audio_len}')
+            sample_length = int(total_audio_len * target_sr)
             samples = samples[:sample_length]
             logging.info(f'new sample lengh: {samples.shape[0]}')
         else:
-            pad_length = int(audio_maxlen * target_sr) - sample_length
-            logging.info(f'padding audio sample from {sample_length / target_sr} to maxlen of {audio_maxlen}')
+            pad_length = int(total_audio_len * target_sr) - sample_length
+            logging.info(f'padding audio sample from {sample_length / target_sr} to maxlen of {total_audio_len}')
             samples = np.pad(samples, (0, pad_length), 'constant', constant_values=(0, 0))
-            sample_length = int(audio_maxlen * target_sr)
+            sample_length = int(total_audio_len * target_sr)
 
     return samples, sample_length
 
-def preprocess_audio(preprocessor, audio, device):
-    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)
-
-    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)
-    processed_signal, processed_signal_length = preprocessor(
-        input_signal=audio_signal, length=audio_signal_len
-    )
-    return processed_signal, processed_signal_length
 
 def extract_preprocessor(model, device):
     cfg = copy.deepcopy(model._cfg)
@@ -135,64 +102,58 @@ def main():
     asr_model = ASRModel.from_pretrained(MODEL)
 
     asr_model.to(DEVICE)
-    asr_model.encoder.eval()
-    asr_model.encoder.freeze()
+    asr_model.eval()
     asr_model._prepare_for_export()
 
-    processor = extract_preprocessor(asr_model, DEVICE)
-
-    input_example, input_example_length = get_samples(WAV, audio_maxlen)
-    logging.info(f'processed example shape: {input_example.shape}')
-    logging.info(f'processed example length shape: {input_example_length}')
-    processed_example, processed_example_length = preprocess_audio(processor, input_example, DEVICE)
-    processed_example = processed_example.repeat(batch_size, 1, 1)
-    processed_example_length = processed_example_length.repeat(batch_size)
-    logging.info(f'processed example shape: {processed_example.size()}')
-    logging.info(f'processed example length shape: {processed_example_length.size()}')
-
-    profiling_context = nullcontext()
-    # if FP16:
-    if PRECISION == '16':
-        precision_context = torch.cuda.amp.autocast()
-    elif PRECISION == 'bf16':
-        precision_context = torch.cuda.amp.autocast(dtype=torch.bfloat16)
-    elif PRECISION == '32':
-        pass
-    else:
-        logging.error(f'unknown precision: {PRECISION}')
-        sys.exit(1)
-
+    preprocessor = extract_preprocessor(asr_model, DEVICE)
+    input_example, input_example_length = get_samples(WAV, total_audio_len)
+    input_example = torch.tensor(input_example).to(DEVICE)
+    input_example = input_example.repeat(batch_size, 1)
+    input_example_length = torch.tensor(input_example_length).to(DEVICE)
+    input_example_length = input_example_length.repeat(batch_size)
 
-    if decoding_type == 'ctc':
-        asr_model.change_decoding_strategy(decoding_cfg=None)
+    processed_signal, processed_signal_length = preprocessor(input_signal=input_example, length=input_example_length)
+    processed_example = processed_signal.repeat(batch_size, 1, 1)
 
+
     logging.info(f"running {nbatches} batches; with {warmup_batches} batches warmup; batch_size: {batch_size}")
     rtfs=[]
     for i in range(3): # average over 3 runs
         total_time = 0
-        with profiling_context:
-            with precision_context:
-                with torch.no_grad():
-                    for i in tqdm(range(nbatches + warmup_batches)):
-
-                        start = time.time()
-                        if decoding_type == 'rnnt':
-                            enc_out, enc_len = asr_model.encoder.forward(audio_signal=processed_example, length=processed_example_length)
-                            dec_out, dec_len = asr_model.decoding.rnnt_decoder_predictions_tensor(
-                                encoder_output=enc_out, encoded_lengths=enc_len, return_hypotheses=False
-                            )
-                        else:
-                            enc_out, enc_len, greedy_predictions = asr_model.forward(processed_signal=processed_example, processed_signal_length=processed_example_length)
-                            dec_out, dec_len = asr_model.decoding.ctc_decoder_predictions_tensor(
-                                enc_out, decoder_lengths=enc_len, return_hypotheses=False
-                            )
-                        torch.cuda.synchronize()
-                        end = time.time()
-                        if i >= warmup_batches:
-                            total_time += end - start
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            with torch.no_grad():
+                for i in tqdm(range(nbatches + warmup_batches)):
+                    start = time.time()
+                    if decoding_type == 'rnnt':
+                        enc_out, enc_len = asr_model.encoder.forward(audio_signal=processed_example, length=processed_signal_length)
+                        dec_out, dec_len = asr_model.decoding.rnnt_decoder_predictions_tensor(
+                            encoder_output=enc_out, encoded_lengths=enc_len, return_hypotheses=False
+                        )
+                    elif decoding_type == 'ctc':
+                        enc_out, enc_len, greedy_predictions = asr_model.forward(input_signal=input_example, input_signal_length=input_example_length)
+                        dec_out, dec_len = asr_model.decoding.ctc_decoder_predictions_tensor(
+                            enc_out, decoder_lengths=enc_len, return_hypotheses=False
+                        )
+                    elif decoding_type == 'aed':
+                        log_probs, encoded_len, enc_states, enc_mask = asr_model.forward(input_signal=input_example, input_signal_length=input_example_length)
+                        beam_hypotheses = asr_model.decoding.decode_predictions_tensor(
+                            encoder_hidden_states=enc_states,
+                            encoder_input_mask=enc_mask,
+                            decoder_input_ids=None,  # torch.tensor([[ 3, 4, 8, 4, 11]]).to(DEVICE),
+                            return_hypotheses=False,
+                        )[0]
+
+                        beam_hypotheses = [asr_model.decoding.strip_special_tokens(text) for text in beam_hypotheses]
+                    else:
+                        raise ValueError(f'Invalid decoding type: {decoding_type}')
+
+                    torch.cuda.synchronize()
+                    end = time.time()
+                    if i >= warmup_batches:
+                        total_time += end - start
 
 
-        rtf = (total_time/nbatches) / (input_example_length / 16000)
+        rtf = (total_time/nbatches) / (float(input_example_length) / 16000)
 
         rtfs.append(rtf)
 
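For reference, the figure of merit computed at the end of each run above is the real-time factor (RTF): average compute time per batch divided by the duration of the padded input audio; its reciprocal is the RTFx throughput number. A standalone restatement with illustrative timings (a sketch, not code from this commit; rtf is our helper name):

def rtf(total_time_s, nbatches, audio_len_samples, sample_rate=16000):
    # Real-time factor: average time to process one batch divided by the
    # audio duration it covers. Values below 1.0 are faster than real time.
    avg_batch_time = total_time_s / nbatches
    audio_duration_s = audio_len_samples / sample_rate
    return avg_batch_time / audio_duration_s

# e.g. 3 timed batches taking 4.5 s in total on a 600 s padded clip:
print(rtf(4.5, 3, 600 * 16000))  # -> 0.0025, i.e. 400x faster than real time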

nemo_asr/run_eval.py

Lines changed: 4 additions & 1 deletion
@@ -92,7 +92,10 @@ def main(args):
     else:
         device = torch.device("cpu")
 
-    asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)  # type: ASRModel
+    if args.model_id.endswith(".nemo"):
+        asr_model = ASRModel.restore_from(args.model_id, map_location=device)
+    else:
+        asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)  # type: ASRModel
     asr_model.freeze()
 
     dataset = data_utils.load_data(args)
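
The new branch lets --model_id point either at a local .nemo checkpoint archive or at a pretrained model name. The same dispatch, pulled out as a standalone helper for clarity (a sketch; load_model is our name, the script inlines this logic):

import torch
from nemo.collections.asr.models import ASRModel

def load_model(model_id: str, device: torch.device) -> ASRModel:
    # Local .nemo archives (e.g. written by model.save_to) are restored
    # directly; anything else is treated as a pretrained model name.
    if model_id.endswith(".nemo"):
        return ASRModel.restore_from(model_id, map_location=device)
    return ASRModel.from_pretrained(model_id, map_location=device)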

requirements/requirements_nemo.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-nemo_toolkit[all]
+git+https://github.com/NVIDIA/[email protected]#egg=nemo_toolkit[all]
 tqdm
 soundfile
 librosa
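
This pins nemo_toolkit[all] to the r1.21.0 branch of the NVIDIA/NeMo GitHub repository (a pip VCS requirement) instead of whatever version PyPI serves, presumably so the AED decoding APIs used by calc_rtf.py resolve consistently across installs.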
