 import argparse
 from typing import Optional
+import json
+import hashlib
+from pathlib import Path
 import datasets
 import evaluate
 import soundfile as sf
 load_dotenv()


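+# Transcription cache: one JSONL file per model/dataset/split under
+# .cache/transcriptions/, with one entry per sample keyed by an MD5 sample id.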
+def get_cache_path(model_name, dataset_path, dataset, split):
+    cache_dir = Path(".cache/transcriptions")
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    cache_key = f"{model_name}_{dataset_path}_{dataset}_{split}".replace("/", "_").replace(":", "_")
+    return cache_dir / f"{cache_key}.jsonl"
+
+
+def load_cache(cache_path):
+    cached_results = {}
+    if cache_path.exists():
+        try:
+            with open(cache_path, "r") as f:
+                for line in f:
+                    if line.strip():
+                        entry = json.loads(line)
+                        cached_results[entry["sample_id"]] = entry
+            print(f"Loaded {len(cached_results)} cached results from {cache_path}")
+        except Exception as e:
+            print(f"Warning: Error loading cache: {e}")
+    return cached_results
+
+
+def save_to_cache(cache_path, sample_id, reference, prediction, audio_duration, transcription_time):
+    entry = {
+        "sample_id": sample_id,
+        "reference": reference,
+        "prediction": prediction,
+        "audio_duration": audio_duration,
+        "transcription_time": transcription_time
+    }
+
+    with open(cache_path, "a") as f:
+        f.write(json.dumps(entry) + "\n")
+
+
+def get_sample_id(sample, index, use_url):
+    """Generate a unique ID for a sample based on its content."""
+    if use_url:
+        id_str = f"{index}_{sample['row']['audio'][0]['src']}"
+    else:
+        # Use the text content for better uniqueness
+        text = sample.get('norm_text', sample.get('text', ''))
+        audio_len = len(sample.get('audio', {}).get('array', [])) if 'audio' in sample else index
+        id_str = f"{index}_{text[:50]}_{audio_len}"
+
+    return hashlib.md5(id_str.encode()).hexdigest()
+
+
 def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
     API_URL = "https://datasets-server.huggingface.co/rows"

@@ -256,7 +308,6 @@ def transcribe_with_retry(
             transcript_text.append(element.value)

         return "".join(transcript_text) if transcript_text else ""
-
     else:
         raise ValueError(
             "Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'"
@@ -289,17 +340,28 @@ def transcribe_dataset(
     use_url=False,
     max_samples=None,
     max_workers=4,
+    clear_cache=False,
 ):
+    cache_path = get_cache_path(model_name, dataset_path, dataset, split)
+
+    if clear_cache and cache_path.exists():
+        print(f"Clearing cache file: {cache_path}")
+        cache_path.unlink()
+
+    cached_results = load_cache(cache_path)
+    print(f"Cache file: {cache_path}")
+
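+    # Materialize the dataset as (index, sample) pairs so each worker can derive
+    # a stable sample id for cache lookups.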
     if use_url:
         audio_rows = fetch_audio_urls(dataset_path, dataset, split)
         if max_samples:
             audio_rows = itertools.islice(audio_rows, max_samples)
-        ds = audio_rows
+        ds = list(enumerate(audio_rows))
     else:
         ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
         ds = data_utils.prepare_data(ds)
         if max_samples:
             ds = ds.take(max_samples)
+        ds = list(enumerate(ds))

     results = {
         "references": [],
@@ -310,7 +372,15 @@ def transcribe_dataset(

     print(f"Transcribing with model: {model_name}")

-    def process_sample(sample):
+    def process_sample(idx_sample):
+        index, sample = idx_sample
+        sample_id = get_sample_id(sample, index, use_url)
+
+        # Check if already cached
+        if sample_id in cached_results:
+            cached = cached_results[sample_id]
+            return cached["reference"], cached["prediction"], cached["audio_duration"], cached["transcription_time"]
+
         if use_url:
             reference = sample["row"]["text"].strip() or " "
             audio_duration = sample["row"]["audio_length_s"]
@@ -353,8 +423,19 @@ def process_sample(sample):
                 print(f"File {tmp_path} does not exist")

         transcription_time = time.time() - start
-        return reference, transcription, audio_duration, transcription_time
-
+
+        # Normalize before caching
+        normalized_reference = data_utils.normalizer(reference) or " "
+        normalized_prediction = data_utils.normalizer(transcription) or " "
+
+        # Save to cache
+        save_to_cache(cache_path, sample_id, normalized_reference, normalized_prediction, audio_duration, transcription_time)
+
+        return normalized_reference, normalized_prediction, audio_duration, transcription_time
+
+    cached_count = sum(1 for idx, sample in ds if get_sample_id(sample, idx, use_url) in cached_results)
+    print(f"Skipping {cached_count} cached samples, processing {len(ds) - cached_count} new samples")
+
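+    # Note: every sample is still submitted to the pool; cache hits simply return
+    # their stored result from process_sample without calling the API again.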
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         future_to_sample = {
             executor.submit(process_sample, sample): sample for sample in ds
@@ -371,15 +452,13 @@ def process_sample(sample):
             results["references"].append(reference)
             results["audio_length_s"].append(audio_duration)
             results["transcription_time_s"].append(transcription_time)
-
     results["predictions"] = [
         data_utils.normalizer(transcription) or " "
         for transcription in results["predictions"]
     ]
     results["references"] = [
         data_utils.normalizer(reference) or " " for reference in results["references"]
     ]
-
     manifest_path = data_utils.write_manifest(
         results["references"],
         results["predictions"],
@@ -420,13 +499,18 @@ def process_sample(sample):
     )
     parser.add_argument("--max_samples", type=int, default=None)
     parser.add_argument(
-        "--max_workers", type=int, default=300, help="Number of concurrent threads"
+        "--max_workers", type=int, default=32, help="Number of concurrent threads"
     )
     parser.add_argument(
         "--use_url",
         action="store_true",
         help="Use URL-based audio fetching instead of datasets",
     )
+    parser.add_argument(
+        "--clear_cache",
+        action="store_true",
+        help="Clear the cache for this model/dataset combination before starting",
+    )

     args = parser.parse_args()

@@ -438,4 +522,5 @@ def process_sample(sample):
         use_url=args.use_url,
         max_samples=args.max_samples,
         max_workers=args.max_workers,
+        clear_cache=args.clear_cache,
     )