Skip to content

Commit a9b2e8b

Browse files
committed
add v2 support
Signed-off-by: nithinraok <[email protected]>
1 parent b5f3da7 commit a9b2e8b

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

nemo_asr/run_eval.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,15 @@ def download_audio_files(batch):
127127
else:
128128
audio_files = all_data["audio_filepaths"]
129129
start_time = time.time()
130-
with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype), torch.inference_mode(), torch.no_grad():
130+
with torch.autocast(device_type="cuda", dtype=compute_dtype), torch.inference_mode(), torch.no_grad():
131+
132+
if 'canary' in args.model_id and 'v2' not in args.model_id:
133+
pnc = 'nopnc'
134+
else:
135+
pnc = 'pnc'
136+
131137
if 'canary' in args.model_id:
132-
transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
138+
transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc, num_workers=1)
133139
else:
134140
transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
135141
end_time = time.time()

0 commit comments

Comments
 (0)