import argparse
import concurrent.futures
import os
import tempfile
import time

import assemblyai as aai
import datasets
import evaluate
import openai
import soundfile as sf
from dotenv import load_dotenv
from elevenlabs.client import ElevenLabs
from rev_ai import apiclient
from tqdm import tqdm

from normalizer import data_utils  # must provide .prepare_data(), .normalizer(), and .write_manifest()

# Read provider API keys from a local .env file.
load_dotenv()
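
# A minimal .env sketch (variable names follow the os.getenv() calls below;
# the values are placeholders, not real credentials):
#
#   ASSEMBLYAI_API_KEY=...
#   ELEVENLABS_API_KEY=...
#   REVAI_API_KEY=...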


def transcribe_with_retry(model_name, audio_file_path, sample, max_retries=10):
    """Transcribe one audio file with the provider named by the model prefix,
    retrying on API errors and returning "." once retries are exhausted."""
    retries = 0
    while retries <= max_retries:
        try:
            if model_name.startswith("assembly/"):
                aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY")
                transcriber = aai.Transcriber()
                config = aai.TranscriptionConfig(
                    speech_model=model_name.split("/")[1],
                    language_code="en",
                )
                # Skip clips shorter than 160 ms rather than submitting them.
                audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
                if audio_duration < 0.160:
                    print(f"Skipping audio with duration {audio_duration}s")
                    return "."
                transcript = transcriber.transcribe(audio_file_path, config=config)
                if transcript.status == aai.TranscriptStatus.error:
                    raise Exception(f"AssemblyAI transcription error: {transcript.error}")
                return transcript.text

            elif model_name.startswith("openai/"):
                with open(audio_file_path, "rb") as audio_file:
                    response = openai.Audio.transcribe(
                        model=model_name.split("/")[1],
                        file=audio_file,
                        response_format="text",
                        language="en",
                        temperature=0.0,
                    )
                return response.strip()

            elif model_name.startswith("elevenlabs/"):
                client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY"))
                with open(audio_file_path, "rb") as audio_file:
                    transcription = client.speech_to_text.convert(
                        file=audio_file,
                        model_id=model_name.split("/")[1],
                        language_code="eng",
                    )
                return transcription.text

            elif model_name.startswith("revai/"):
                access_token = os.getenv("REVAI_API_KEY")
                client = apiclient.RevAiAPIClient(access_token)

                # Submit a job for the local file
                job = client.submit_job_local_file(
                    transcriber=model_name.split("/")[1],
                    filename=audio_file_path,
                    metadata="benchmarking_job",
                    remove_disfluencies=True,
                    remove_atmospherics=True,
                )

                # Poll until the job finishes
                while True:
                    job_details = client.get_job_details(job.id)
                    if job_details.status.name in ["IN_PROGRESS", "TRANSCRIBING"]:
                        time.sleep(0.1)
                        continue
                    elif job_details.status.name == "FAILED":
                        raise Exception("RevAI transcription failed.")
                    elif job_details.status.name == "TRANSCRIBED":
                        break

                transcript_object = client.get_transcript_object(job.id)

                # Rev AI emits words and punctuation/spacing as separate
                # elements, so plain concatenation reconstructs the text.
                transcript_text = []
                for monologue in transcript_object.monologues:
                    for element in monologue.elements:
                        transcript_text.append(element.value)

                return "".join(transcript_text) if transcript_text else ""

            else:
                raise ValueError("Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/', or 'revai/'")

        except Exception as e:
            retries += 1
            if retries > max_retries:
                return "."

            # Re-write the temp file before retrying, in case a failed upload
            # left it in a bad state.
            sf.write(audio_file_path, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV")
            delay = 1
            print(f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})")
            time.sleep(delay)
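

# Example (hypothetical values): for a dataset row `sample` holding an "audio"
# dict, transcribe_with_retry("openai/whisper-1", "/tmp/clip.wav", sample)
# returns the transcript text, or "." once retries are exhausted.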

def transcribe_dataset(dataset_path, dataset, split, model_name, max_samples=None, max_workers=4):
    """Transcribe a dataset split concurrently, save a manifest, and report WER and RTFx."""
    ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
    ds = data_utils.prepare_data(ds)
    if max_samples:
        ds = ds.take(max_samples)

    results = {"references": [], "predictions": [], "audio_length_s": [], "transcription_time_s": []}

    print(f"Transcribing with model: {model_name}")

    def process_sample(sample):
        reference = sample.get("norm_text", "").strip() or " "

        # Write the audio to a temp .wav file (delete=False so it can be
        # reopened by path and cleaned up manually below).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV")
            tmp_path = tmpfile.name

        start = time.time()
        try:
            transcription = transcribe_with_retry(model_name, tmp_path, sample)
        except Exception as e:
            print(f"Failed to transcribe after retries: {e}")
            return None
        finally:
            # Clean up the temp file exactly once, success or failure.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

        transcription_time = time.time() - start
        audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
        transcription = data_utils.normalizer(transcription) or " "
        return reference, transcription, audio_duration, transcription_time

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_sample = {executor.submit(process_sample, sample): sample for sample in ds}
        for future in tqdm(concurrent.futures.as_completed(future_to_sample), total=len(future_to_sample), desc="Transcribing"):
            result = future.result()
            if result:
                reference, transcription, audio_duration, transcription_time = result
                results["predictions"].append(transcription)
                results["references"].append(reference)
                results["audio_length_s"].append(audio_duration)
                results["transcription_time_s"].append(transcription_time)

    # Save results to a manifest file
    manifest_path = data_utils.write_manifest(
        results["references"],
        results["predictions"],
        model_name.replace("/", "-"),
        dataset_path,
        dataset,
        split,
        audio_length=results["audio_length_s"],
        transcription_time=results["transcription_time_s"],
    )
    print("Results saved at path:", manifest_path)

    # Evaluate: corpus-level WER and inverse real-time factor (RTFx)
    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(references=results["references"], predictions=results["predictions"])
    wer_percent = round(100 * wer, 2)
    rtfx = round(sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2)

    print("WER:", wer_percent, "%")
    print("RTFx:", rtfx)
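
# Note: RTFx above is total audio seconds divided by the sum of per-sample
# transcription times (e.g. 3600 s of audio transcribed in a summed 36 s
# gives RTFx 100). Under the thread pool those per-sample times overlap, so
# this figure can understate true wall-clock throughput.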


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unified transcription script with concurrency")
    parser.add_argument("--dataset_path", required=True, help="Dataset path or name")
    parser.add_argument("--dataset", required=True, help="Subset name of the dataset")
    parser.add_argument("--split", default="test", help="Dataset split")
    parser.add_argument("--model_name", required=True, help="Prefix the model name with 'assembly/', 'openai/', 'elevenlabs/', or 'revai/'")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to transcribe")
    parser.add_argument("--max_workers", type=int, default=50, help="Number of concurrent threads")

    args = parser.parse_args()

    transcribe_dataset(
        dataset_path=args.dataset_path,
        dataset=args.dataset,
        split=args.split,
        model_name=args.model_name,
        max_samples=args.max_samples,
        max_workers=args.max_workers,
    )
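
# Example invocation (hypothetical file/dataset names; any HF ASR dataset
# with an "audio" column that data_utils.prepare_data() accepts should work):
#
#   python transcribe_api_models.py \
#       --dataset_path hf-audio/esb-datasets-test-only-sorted \
#       --dataset ami \
#       --split test \
#       --model_name openai/whisper-1 \
#       --max_samples 100 \
#       --max_workers 8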