Skip to content

Commit 83cc90f

Browse files
author
Nithin Rao Koluguri
committed
minor clean up and warmup upto 4 batches
Signed-off-by: Nithin Rao Koluguri <nithinraok>
1 parent 8320a21 commit 83cc90f

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

nemo_asr/run_eval.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22

33
import os
4-
import shutil
54
import torch
65
import evaluate
76
import soundfile
@@ -89,7 +88,6 @@ def download_audio_files(batch):
8988

9089
data_itr = iter(dataset)
9190
for data in tqdm(data_itr, desc="Downloading Samples"):
92-
# import ipdb; ipdb.set_trace()
9391
for key in all_data:
9492
all_data[key].append(data[key])
9593

@@ -101,14 +99,17 @@ def download_audio_files(batch):
10199

102100

103101
total_time = 0
104-
for _ in range(2): # warmup once and calculate rtf
102+
for _ in range(2): # warmup once and calculate rtf
103+
if _ == 0:
104+
audio_files = all_data["audio_filepaths"][:256] # warmup with 4 batches
105+
else:
106+
audio_files = all_data["audio_filepaths"]
105107
start_time = time.time()
106-
with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype):
107-
with torch.no_grad():
108-
if 'canary' in args.model_id:
109-
transcriptions = asr_model.transcribe(all_data["audio_filepaths"], batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
110-
else:
111-
transcriptions = asr_model.transcribe(all_data["audio_filepaths"], batch_size=args.batch_size, verbose=False, num_workers=1)
108+
with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype), torch.inference_mode(), torch.no_grad():
109+
if 'canary' in args.model_id:
110+
transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
111+
else:
112+
transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
112113
end_time = time.time()
113114
if _ == 1:
114115
total_time += end_time - start_time

0 commit comments

Comments
 (0)