diff --git a/api/run_eval.py b/api/run_eval.py
index 69b588a..7e71787 100644
--- a/api/run_eval.py
+++ b/api/run_eval.py
@@ -1,5 +1,8 @@
 import argparse
 from typing import Optional
+import json
+import hashlib
+from pathlib import Path
 import datasets
 import evaluate
 import soundfile as sf
@@ -26,6 +29,55 @@
 load_dotenv()
 
 
+def get_cache_path(model_name, dataset_path, dataset, split):
+    cache_dir = Path(".cache/transcriptions")
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    cache_key = f"{model_name}_{dataset_path}_{dataset}_{split}".replace("/", "_").replace(":", "_")
+    return cache_dir / f"{cache_key}.jsonl"
+
+
+def load_cache(cache_path):
+    cached_results = {}
+    if cache_path.exists():
+        try:
+            with open(cache_path, "r") as f:
+                for line in f:
+                    if line.strip():
+                        entry = json.loads(line)
+                        cached_results[entry["sample_id"]] = entry
+            print(f"Loaded {len(cached_results)} cached results from {cache_path}")
+        except Exception as e:
+            print(f"Warning: Error loading cache: {e}")
+    return cached_results
+
+
+def save_to_cache(cache_path, sample_id, reference, prediction, audio_duration, transcription_time):
+    entry = {
+        "sample_id": sample_id,
+        "reference": reference,
+        "prediction": prediction,
+        "audio_duration": audio_duration,
+        "transcription_time": transcription_time
+    }
+
+    with open(cache_path, "a") as f:
+        f.write(json.dumps(entry) + "\n")
+
+
+def get_sample_id(sample, index, use_url):
+    """Generate a unique ID for a sample based on its content."""
+    if use_url:
+        id_str = f"{index}_{sample['row']['audio'][0]['src']}"
+    else:
+        # Use the text content for better uniqueness
+        text = sample.get('norm_text', sample.get('text', ''))
+        audio_len = len(sample.get('audio', {}).get('array', [])) if 'audio' in sample else index
+        id_str = f"{index}_{text[:50]}_{audio_len}"
+
+    return hashlib.md5(id_str.encode()).hexdigest()
+
+
 def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
     API_URL = "https://datasets-server.huggingface.co/rows"
 
@@ -256,7 +308,6 @@ def transcribe_with_retry(
             transcript_text.append(element.value)
         return "".join(transcript_text) if transcript_text else ""
-
     else:
         raise ValueError(
             "Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'"
         )
@@ -289,17 +340,28 @@ def transcribe_dataset(
     use_url=False,
     max_samples=None,
     max_workers=4,
+    clear_cache=False,
 ):
+    cache_path = get_cache_path(model_name, dataset_path, dataset, split)
+
+    if clear_cache and cache_path.exists():
+        print(f"Clearing cache file: {cache_path}")
+        cache_path.unlink()
+
+    cached_results = load_cache(cache_path)
+    print(f"Cache file: {cache_path}")
+
     if use_url:
         audio_rows = fetch_audio_urls(dataset_path, dataset, split)
         if max_samples:
             audio_rows = itertools.islice(audio_rows, max_samples)
-        ds = audio_rows
+        ds = list(enumerate(audio_rows))
     else:
         ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
         ds = data_utils.prepare_data(ds)
         if max_samples:
             ds = ds.take(max_samples)
+        ds = list(enumerate(ds))
 
     results = {
         "references": [],
@@ -310,7 +372,14 @@
 
     print(f"Transcribing with model: {model_name}")
 
-    def process_sample(sample):
+    def process_sample(idx_sample):
+        index, sample = idx_sample
+        sample_id = get_sample_id(sample, index, use_url)
+
+        if sample_id in cached_results:
+            cached = cached_results[sample_id]
+            return cached["reference"], cached["prediction"], cached["audio_duration"], cached["transcription_time"]
+
         if use_url:
             reference = sample["row"]["text"].strip() or " "
             audio_duration = sample["row"]["audio_length_s"]
@@ -353,8 +422,17 @@ def process_sample(sample):
             print(f"File {tmp_path} does not exist")
 
         transcription_time = time.time() - start
-        return reference, transcription, audio_duration, transcription_time
-
+
+        normalized_reference = data_utils.normalizer(reference) or " "
+        normalized_prediction = data_utils.normalizer(transcription) or " "
+
+        save_to_cache(cache_path, sample_id, normalized_reference, normalized_prediction, audio_duration, transcription_time)
+
+        return normalized_reference, normalized_prediction, audio_duration, transcription_time
+
+    cached_count = sum(1 for idx, sample in ds if get_sample_id(sample, idx, use_url) in cached_results)
+    print(f"Skipping {cached_count} cached samples, processing {len(ds) - cached_count} new samples")
+
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         future_to_sample = {
             executor.submit(process_sample, sample): sample for sample in ds
@@ -371,7 +449,6 @@ def process_sample(sample):
                 results["references"].append(reference)
                 results["audio_length_s"].append(audio_duration)
                 results["transcription_time_s"].append(transcription_time)
-
     results["predictions"] = [
         data_utils.normalizer(transcription) or " "
         for transcription in results["predictions"]
@@ -379,7 +456,6 @@ def process_sample(sample):
     results["references"] = [
         data_utils.normalizer(reference) or " " for reference in results["references"]
     ]
-
     manifest_path = data_utils.write_manifest(
         results["references"],
         results["predictions"],
@@ -420,13 +496,18 @@ def process_sample(sample):
     )
     parser.add_argument("--max_samples", type=int, default=None)
     parser.add_argument(
-        "--max_workers", type=int, default=300, help="Number of concurrent threads"
+        "--max_workers", type=int, default=32, help="Number of concurrent threads"
     )
     parser.add_argument(
         "--use_url",
         action="store_true",
         help="Use URL-based audio fetching instead of datasets",
     )
+    parser.add_argument(
+        "--clear_cache",
+        action="store_true",
+        help="Clear the cache for this model/dataset combination before starting",
+    )
 
     args = parser.parse_args()
 
@@ -438,4 +519,5 @@ def process_sample(sample):
         use_url=args.use_url,
         max_samples=args.max_samples,
         max_workers=args.max_workers,
+        clear_cache=args.clear_cache,
     )
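
The cache written by this change is plain JSONL, one record per completed sample keyed by an MD5 sample ID, so an interrupted eval resumes by simply re-running the same command, and --clear_cache forces a fresh run. A minimal sketch (not part of the diff) for inspecting cache files offline, assuming only the on-disk layout defined by get_cache_path and save_to_cache above:

import json
from pathlib import Path

# One JSONL file per model/dataset/split combination; file names are
# derived by get_cache_path with "/" and ":" mapped to "_".
for cache_path in Path(".cache/transcriptions").glob("*.jsonl"):
    with open(cache_path) as f:
        entries = [json.loads(line) for line in f if line.strip()]

    # Fields per entry match save_to_cache: sample_id, reference,
    # prediction, audio_duration, transcription_time.
    total_audio = sum(e["audio_duration"] for e in entries)
    total_time = sum(e["transcription_time"] for e in entries)
    print(f"{cache_path.name}: {len(entries)} samples, "
          f"{total_audio:.1f}s audio, {total_time:.1f}s transcription time")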