11import argparse
22
33import os
4- import shutil
54import torch
65import evaluate
76import soundfile
@@ -89,7 +88,6 @@ def download_audio_files(batch):
8988
9089 data_itr = iter (dataset )
9190 for data in tqdm (data_itr , desc = "Downloading Samples" ):
92- # import ipdb; ipdb.set_trace()
9391 for key in all_data :
9492 all_data [key ].append (data [key ])
9593
@@ -101,14 +99,17 @@ def download_audio_files(batch):
10199
102100
103101 total_time = 0
104- for _ in range (2 ): # warmup once and calculate rtf
102+ for _ in range (2 ): # warmup once and calculate rtf
103+ if _ == 0 :
104+ audio_files = all_data ["audio_filepaths" ][:256 ] # warmup with 4 batches
105+ else :
106+ audio_files = all_data ["audio_filepaths" ]
105107 start_time = time .time ()
106- with torch .cuda .amp .autocast (enabled = False , dtype = compute_dtype ):
107- with torch .no_grad ():
108- if 'canary' in args .model_id :
109- transcriptions = asr_model .transcribe (all_data ["audio_filepaths" ], batch_size = args .batch_size , verbose = False , pnc = 'no' , num_workers = 1 )
110- else :
111- transcriptions = asr_model .transcribe (all_data ["audio_filepaths" ], batch_size = args .batch_size , verbose = False , num_workers = 1 )
108+ with torch .cuda .amp .autocast (enabled = False , dtype = compute_dtype ), torch .inference_mode (), torch .no_grad ():
109+ if 'canary' in args .model_id :
110+ transcriptions = asr_model .transcribe (audio_files , batch_size = args .batch_size , verbose = False , pnc = 'no' , num_workers = 1 )
111+ else :
112+ transcriptions = asr_model .transcribe (audio_files , batch_size = args .batch_size , verbose = False , num_workers = 1 )
112113 end_time = time .time ()
113114 if _ == 1 :
114115 total_time += end_time - start_time
0 commit comments