98 changes: 90 additions & 8 deletions api/run_eval.py
@@ -1,5 +1,8 @@
import argparse
from typing import Optional
import json
import hashlib
from pathlib import Path
import datasets
import evaluate
import soundfile as sf
@@ -26,6 +29,55 @@
load_dotenv()


def get_cache_path(model_name, dataset_path, dataset, split):
cache_dir = Path(".cache/transcriptions")
cache_dir.mkdir(parents=True, exist_ok=True)

cache_key = f"{model_name}_{dataset_path}_{dataset}_{split}".replace("/", "_").replace(":", "_")
return cache_dir / f"{cache_key}.jsonl"


def load_cache(cache_path):
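    # Each cache line is a self-contained JSON object keyed by "sample_id";
    # if a sample_id appears more than once, the last entry wins.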
cached_results = {}
if cache_path.exists():
try:
with open(cache_path, "r") as f:
for line in f:
if line.strip():
entry = json.loads(line)
cached_results[entry["sample_id"]] = entry
print(f"Loaded {len(cached_results)} cached results from {cache_path}")
except Exception as e:
print(f"Warning: Error loading cache: {e}")
return cached_results


def save_to_cache(cache_path, sample_id, reference, prediction, audio_duration, transcription_time):
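    # Append-only JSONL write. Note this is called concurrently from worker
    # threads without a lock; it relies on small single-line appends being
    # effectively atomic (generally true on POSIX, not guaranteed by Python).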
entry = {
"sample_id": sample_id,
"reference": reference,
"prediction": prediction,
"audio_duration": audio_duration,
"transcription_time": transcription_time
}

with open(cache_path, "a") as f:
f.write(json.dumps(entry) + "\n")


def get_sample_id(sample, index, use_url):
"""Generate a unique ID for a sample based on its content."""
if use_url:
id_str = f"{index}_{sample['row']['audio'][0]['src']}"
else:
# Use the text content for better uniqueness
text = sample.get('norm_text', sample.get('text', ''))
audio_len = len(sample.get('audio', {}).get('array', [])) if 'audio' in sample else index
id_str = f"{index}_{text[:50]}_{audio_len}"

return hashlib.md5(id_str.encode()).hexdigest()


def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
API_URL = "https://datasets-server.huggingface.co/rows"

@@ -256,7 +308,6 @@ def transcribe_with_retry(
transcript_text.append(element.value)

return "".join(transcript_text) if transcript_text else ""

else:
raise ValueError(
"Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'"
@@ -289,17 +340,28 @@ def transcribe_dataset(
use_url=False,
max_samples=None,
max_workers=4,
clear_cache=False,
):
cache_path = get_cache_path(model_name, dataset_path, dataset, split)

if clear_cache and cache_path.exists():
print(f"Clearing cache file: {cache_path}")
cache_path.unlink()

cached_results = load_cache(cache_path)
print(f"Cache file: {cache_path}")

if use_url:
audio_rows = fetch_audio_urls(dataset_path, dataset, split)
if max_samples:
audio_rows = itertools.islice(audio_rows, max_samples)
ds = audio_rows
ds = list(enumerate(audio_rows))
else:
ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
ds = data_utils.prepare_data(ds)
if max_samples:
ds = ds.take(max_samples)
ds = list(enumerate(ds))
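    # Both branches materialize (index, sample) pairs so each sample gets a
    # stable index for cache keying and len(ds) is available for the summary.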

results = {
"references": [],
@@ -310,7 +372,14 @@

print(f"Transcribing with model: {model_name}")

def process_sample(sample):
def process_sample(idx_sample):
index, sample = idx_sample
sample_id = get_sample_id(sample, index, use_url)
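        # Fast path: a sample already in the cache is returned as-is, with its
        # stored normalized texts and originally measured timings.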

if sample_id in cached_results:
cached = cached_results[sample_id]
return cached["reference"], cached["prediction"], cached["audio_duration"], cached["transcription_time"]

if use_url:
reference = sample["row"]["text"].strip() or " "
audio_duration = sample["row"]["audio_length_s"]
@@ -353,8 +422,17 @@ def process_sample(sample):
print(f"File {tmp_path} does not exist")

transcription_time = time.time() - start
return reference, transcription, audio_duration, transcription_time


normalized_reference = data_utils.normalizer(reference) or " "
normalized_prediction = data_utils.normalizer(transcription) or " "
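        # Normalized texts are cached (rather than raw ones) so resumed runs
        # can reuse them directly without re-running the normalizer.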

save_to_cache(cache_path, sample_id, normalized_reference, normalized_prediction, audio_duration, transcription_time)

return normalized_reference, normalized_prediction, audio_duration, transcription_time

cached_count = sum(1 for idx, sample in ds if get_sample_id(sample, idx, use_url) in cached_results)
print(f"Skipping {cached_count} cached samples, processing {len(ds) - cached_count} new samples")

with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_sample = {
executor.submit(process_sample, sample): sample for sample in ds
@@ -371,15 +449,13 @@ def process_sample(sample):
results["references"].append(reference)
results["audio_length_s"].append(audio_duration)
results["transcription_time_s"].append(transcription_time)

results["predictions"] = [
data_utils.normalizer(transcription) or " "
for transcription in results["predictions"]
]
results["references"] = [
data_utils.normalizer(reference) or " " for reference in results["references"]
]

manifest_path = data_utils.write_manifest(
results["references"],
results["predictions"],
@@ -420,13 +496,18 @@ def process_sample(sample):
)
parser.add_argument("--max_samples", type=int, default=None)
parser.add_argument(
"--max_workers", type=int, default=300, help="Number of concurrent threads"
"--max_workers", type=int, default=32, help="Number of concurrent threads"
)
parser.add_argument(
"--use_url",
action="store_true",
help="Use URL-based audio fetching instead of datasets",
)
parser.add_argument(
"--clear_cache",
action="store_true",
help="Clear the cache for this model/dataset combination before starting",
)
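    # Illustrative invocations ("..." stands for the model/dataset flags
    # defined earlier in this file):
    #   python api/run_eval.py ... --max_workers 32    # resumes from cache
    #   python api/run_eval.py ... --clear_cache       # ignores old cache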

args = parser.parse_args()

@@ -438,4 +519,5 @@ def process_sample(sample):
use_url=args.use_url,
max_samples=args.max_samples,
max_workers=args.max_workers,
clear_cache=args.clear_cache,
)