|
| 1 | +import argparse |
| 2 | +import datasets |
| 3 | +import evaluate |
| 4 | +import soundfile as sf |
| 5 | +import tempfile |
| 6 | +import time |
| 7 | +import os |
| 8 | +import requests |
| 9 | +from tqdm import tqdm |
| 10 | +from dotenv import load_dotenv |
| 11 | +from io import BytesIO |
| 12 | +import assemblyai as aai |
| 13 | +import openai |
| 14 | +from elevenlabs.client import ElevenLabs |
| 15 | +from rev_ai import apiclient |
| 16 | +from rev_ai.models import CustomVocabulary, CustomerUrlData |
| 17 | +from normalizer import data_utils |
| 18 | +import concurrent.futures |
| 19 | + |
| 20 | +load_dotenv() |
| 21 | + |
| 22 | +def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20): |
| 23 | + API_URL = "https://datasets-server.huggingface.co/rows" |
| 24 | + |
| 25 | + size_url = f"https://datasets-server.huggingface.co/size?dataset={dataset_path}&config={dataset}&split={split}" |
| 26 | + size_response = requests.get(size_url).json() |
| 27 | + total_rows = size_response['size']['config']['num_rows'] |
| 28 | + audio_urls = [] |
| 29 | + for offset in tqdm(range(0, total_rows, batch_size), desc="Fetching audio URLs"): |
| 30 | + params = { |
| 31 | + "dataset": dataset_path, |
| 32 | + "config": dataset, |
| 33 | + "split": split, |
| 34 | + "offset": offset, |
| 35 | + "length": min(batch_size, total_rows - offset) |
| 36 | + } |
| 37 | + |
| 38 | + retries = 0 |
| 39 | + while retries <= max_retries: |
| 40 | + try: |
| 41 | + response = requests.get(API_URL, params=params) |
| 42 | + response.raise_for_status() |
| 43 | + data = response.json() |
| 44 | + audio_urls.extend(data['rows']) |
| 45 | + break |
| 46 | + except (requests.exceptions.RequestException, ValueError) as e: |
| 47 | + retries += 1 |
| 48 | + print(f"Error fetching data: {e}, retrying ({retries}/{max_retries})...") |
| 49 | + time.sleep(10) |
| 50 | + if retries >= max_retries: |
| 51 | + raise Exception("Max retries exceeded while fetching data.") |
| 52 | + time.sleep(1) |
| 53 | + return audio_urls |
| 54 | + |
| 55 | +def transcribe_with_retry(model_name, audio_file_path, sample, max_retries=10, use_url=False): |
| 56 | + retries = 0 |
| 57 | + while retries <= max_retries: |
| 58 | + try: |
| 59 | + if model_name.startswith("assembly/"): |
| 60 | + aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY") |
| 61 | + transcriber = aai.Transcriber() |
| 62 | + config = aai.TranscriptionConfig( |
| 63 | + speech_model=model_name.split("/")[1], |
| 64 | + language_code="en", |
| 65 | + ) |
| 66 | + if use_url: |
| 67 | + audio_url = sample['row']['audio'][0]['src'] |
| 68 | + audio_duration = sample['row']['audio_length_s'] |
| 69 | + if audio_duration < 0.160: |
| 70 | + print(f"Skipping audio duration {audio_duration}s") |
| 71 | + return "." |
| 72 | + transcript = transcriber.transcribe(audio_url, config=config) |
| 73 | + else: |
| 74 | + audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"] |
| 75 | + if audio_duration < 0.160: |
| 76 | + print(f"Skipping audio duration {audio_duration}s") |
| 77 | + return "." |
| 78 | + transcript = transcriber.transcribe(audio_file_path, config=config) |
| 79 | + |
| 80 | + if transcript.status == aai.TranscriptStatus.error: |
| 81 | + raise Exception(f"AssemblyAI transcription error: {transcript.error}") |
| 82 | + return transcript.text |
| 83 | + |
| 84 | + elif model_name.startswith("openai/"): |
| 85 | + if use_url: |
| 86 | + response = requests.get(sample['row']['audio'][0]['src']) |
| 87 | + audio_data = BytesIO(response.content) |
| 88 | + response = openai.Audio.transcribe( |
| 89 | + model=model_name.split("/")[1], |
| 90 | + file=audio_data, |
| 91 | + response_format="text", |
| 92 | + language="en", |
| 93 | + temperature=0.0, |
| 94 | + ) |
| 95 | + else: |
| 96 | + with open(audio_file_path, "rb") as audio_file: |
| 97 | + response = openai.Audio.transcribe( |
| 98 | + model=model_name.split("/")[1], |
| 99 | + file=audio_file, |
| 100 | + response_format="text", |
| 101 | + language="en", |
| 102 | + temperature=0.0, |
| 103 | + ) |
| 104 | + return response.strip() |
| 105 | + |
| 106 | + elif model_name.startswith("elevenlabs/"): |
| 107 | + client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY")) |
| 108 | + if use_url: |
| 109 | + response = requests.get(sample['row']['audio'][0]['src']) |
| 110 | + audio_data = BytesIO(response.content) |
| 111 | + transcription = client.speech_to_text.convert( |
| 112 | + file=audio_data, |
| 113 | + model_id=model_name.split("/")[1], |
| 114 | + language_code="eng", |
| 115 | + tag_audio_events=True, |
| 116 | + |
| 117 | + ) |
| 118 | + else: |
| 119 | + with open(audio_file_path, "rb") as audio_file: |
| 120 | + transcription = client.speech_to_text.convert( |
| 121 | + file=audio_file, |
| 122 | + model_id=model_name.split("/")[1], |
| 123 | + language_code="eng", |
| 124 | + tag_audio_events=True, |
| 125 | + ) |
| 126 | + return transcription.text |
| 127 | + |
| 128 | + elif model_name.startswith("revai/"): |
| 129 | + access_token = os.getenv("REVAI_API_KEY") |
| 130 | + client = apiclient.RevAiAPIClient(access_token) |
| 131 | + |
| 132 | + if use_url: |
| 133 | + # Submit job with URL for Rev.ai |
| 134 | + job = client.submit_job_url( |
| 135 | + transcriber=model_name.split("/")[1], |
| 136 | + source_config=CustomerUrlData(sample['row']['audio'][0]['src']), |
| 137 | + metadata="benchmarking_job", |
| 138 | + ) |
| 139 | + else: |
| 140 | + # Submit job with local file |
| 141 | + job = client.submit_job_local_file( |
| 142 | + transcriber=model_name.split("/")[1], |
| 143 | + filename=audio_file_path, |
| 144 | + metadata="benchmarking_job", |
| 145 | + ) |
| 146 | + |
| 147 | + # Polling until job is done |
| 148 | + while True: |
| 149 | + job_details = client.get_job_details(job.id) |
| 150 | + if job_details.status.name in ["IN_PROGRESS", "TRANSCRIBING"]: |
| 151 | + time.sleep(0.1) |
| 152 | + continue |
| 153 | + elif job_details.status.name == "FAILED": |
| 154 | + raise Exception("RevAI transcription failed.") |
| 155 | + elif job_details.status.name == "TRANSCRIBED": |
| 156 | + break |
| 157 | + |
| 158 | + transcript_object = client.get_transcript_object(job.id) |
| 159 | + |
| 160 | + # Combine all words from all monologues |
| 161 | + transcript_text = [] |
| 162 | + for monologue in transcript_object.monologues: |
| 163 | + for element in monologue.elements: |
| 164 | + transcript_text.append(element.value) |
| 165 | + |
| 166 | + return "".join(transcript_text) if transcript_text else "" |
| 167 | + |
| 168 | + else: |
| 169 | + raise ValueError("Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'") |
| 170 | + |
| 171 | + except Exception as e: |
| 172 | + retries += 1 |
| 173 | + if retries > max_retries: |
| 174 | + return "." |
| 175 | + |
| 176 | + if not use_url: |
| 177 | + sf.write(audio_file_path, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV") |
| 178 | + delay = 1 |
| 179 | + print(f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})") |
| 180 | + time.sleep(delay) |
| 181 | + |
| 182 | + |
| 183 | +def transcribe_dataset(dataset_path, dataset, split, model_name, use_url=False, max_samples=None, max_workers=4): |
| 184 | + if use_url: |
| 185 | + audio_rows = fetch_audio_urls(dataset_path, dataset, split) |
| 186 | + if max_samples: |
| 187 | + audio_rows = audio_rows[:max_samples] |
| 188 | + ds = audio_rows |
| 189 | + else: |
| 190 | + ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False) |
| 191 | + ds = data_utils.prepare_data(ds) |
| 192 | + if max_samples: |
| 193 | + ds = ds.take(max_samples) |
| 194 | + |
| 195 | + results = {"references": [], "predictions": [], "audio_length_s": [], "transcription_time_s": []} |
| 196 | + |
| 197 | + print(f"Transcribing with model: {model_name}") |
| 198 | + |
| 199 | + def process_sample(sample): |
| 200 | + if use_url: |
| 201 | + reference = sample['row']['text'].strip() or " " |
| 202 | + audio_duration = sample['row']['audio_length_s'] |
| 203 | + start = time.time() |
| 204 | + try: |
| 205 | + transcription = transcribe_with_retry(model_name, None, sample, use_url=True) |
| 206 | + except Exception as e: |
| 207 | + print(f"Failed to transcribe after retries: {e}") |
| 208 | + return None |
| 209 | + |
| 210 | + else: |
| 211 | + reference = sample.get("norm_text", "").strip() or " " |
| 212 | + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: |
| 213 | + sf.write(tmpfile.name, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV") |
| 214 | + tmp_path = tmpfile.name |
| 215 | + audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"] |
| 216 | + |
| 217 | + start = time.time() |
| 218 | + try: |
| 219 | + transcription = transcribe_with_retry(model_name, tmp_path, sample, use_url=False) |
| 220 | + except Exception as e: |
| 221 | + print(f"Failed to transcribe after retries: {e}") |
| 222 | + os.unlink(tmp_path) |
| 223 | + return None |
| 224 | + finally: |
| 225 | + if os.path.exists(tmp_path): |
| 226 | + os.unlink(tmp_path) |
| 227 | + else: |
| 228 | + print(f"File {tmp_path} does not exist") |
| 229 | + |
| 230 | + transcription_time = time.time() - start |
| 231 | + return reference, transcription, audio_duration, transcription_time |
| 232 | + |
| 233 | + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: |
| 234 | + future_to_sample = {executor.submit(process_sample, sample): sample for sample in ds} |
| 235 | + for future in tqdm(concurrent.futures.as_completed(future_to_sample), total=len(future_to_sample), desc="Transcribing"): |
| 236 | + result = future.result() |
| 237 | + if result: |
| 238 | + reference, transcription, audio_duration, transcription_time = result |
| 239 | + results["predictions"].append(transcription) |
| 240 | + results["references"].append(reference) |
| 241 | + results["audio_length_s"].append(audio_duration) |
| 242 | + results["transcription_time_s"].append(transcription_time) |
| 243 | + |
| 244 | + results["predictions"] = [data_utils.normalizer(transcription) or " " for transcription in results["predictions"]] |
| 245 | + results["references"] = [data_utils.normalizer(reference) or " " for reference in results["references"]] |
| 246 | + |
| 247 | + manifest_path = data_utils.write_manifest( |
| 248 | + results["references"], |
| 249 | + results["predictions"], |
| 250 | + model_name.replace("/", "-"), |
| 251 | + dataset_path, |
| 252 | + dataset, |
| 253 | + split, |
| 254 | + audio_length=results["audio_length_s"], |
| 255 | + transcription_time=results["transcription_time_s"], |
| 256 | + ) |
| 257 | + |
| 258 | + print("Results saved at path:", manifest_path) |
| 259 | + |
| 260 | + wer_metric = evaluate.load("wer") |
| 261 | + wer = wer_metric.compute(references=results["references"], predictions=results["predictions"]) |
| 262 | + wer_percent = round(100 * wer, 2) |
| 263 | + rtfx = round(sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2) |
| 264 | + |
| 265 | + print("WER:", wer_percent, "%") |
| 266 | + print("RTFx:", rtfx) |
| 267 | + |
| 268 | + |
| 269 | +if __name__ == "__main__": |
| 270 | + parser = argparse.ArgumentParser(description="Unified Transcription Script with Concurrency") |
| 271 | + parser.add_argument("--dataset_path", required=True) |
| 272 | + parser.add_argument("--dataset", required=True) |
| 273 | + parser.add_argument("--split", default="test") |
| 274 | + parser.add_argument("--model_name", required=True, help="Prefix model name with 'assembly/', 'openai/', or 'elevenlabs/'") |
| 275 | + parser.add_argument("--max_samples", type=int, default=None) |
| 276 | + parser.add_argument("--max_workers", type=int, default=300, help="Number of concurrent threads") |
| 277 | + parser.add_argument("--use_url", action="store_true", help="Use URL-based audio fetching instead of datasets") |
| 278 | + |
| 279 | + args = parser.parse_args() |
| 280 | + |
| 281 | + transcribe_dataset( |
| 282 | + dataset_path=args.dataset_path, |
| 283 | + dataset=args.dataset, |
| 284 | + split=args.split, |
| 285 | + model_name=args.model_name, |
| 286 | + use_url=args.use_url, |
| 287 | + max_samples=args.max_samples, |
| 288 | + max_workers=args.max_workers, |
| 289 | + ) |
0 commit comments