Skip to content

Commit b5b520b

Browse files
author
Jack
committed
add cache for api
1 parent ce4cfb6 commit b5b520b

File tree

1 file changed

+93
-8
lines changed

1 file changed

+93
-8
lines changed

api/run_eval.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import argparse
22
from typing import Optional
3+
import json
4+
import hashlib
5+
from pathlib import Path
36
import datasets
47
import evaluate
58
import soundfile as sf
@@ -26,6 +29,55 @@
2629
load_dotenv()
2730

2831

32+
def get_cache_path(model_name, dataset_path, dataset, split):
    """Return the JSONL cache file path for a model/dataset/split combination.

    Ensures the cache directory `.cache/transcriptions/` exists (created on
    first call). Path separators and colons in the identifiers are replaced
    with underscores so the derived key is a valid single filename.
    """
    cache_dir = Path(".cache/transcriptions")
    cache_dir.mkdir(parents=True, exist_ok=True)

    raw_key = f"{model_name}_{dataset_path}_{dataset}_{split}"
    safe_key = raw_key.replace("/", "_").replace(":", "_")
    return cache_dir / (safe_key + ".jsonl")
38+
39+
40+
def load_cache(cache_path):
    """Load previously cached transcription results from a JSONL file.

    Each line of the file is a JSON object containing at least a
    "sample_id" key. Malformed or incomplete lines are skipped
    individually so that a single corrupt entry does not discard the
    rest of the cache (previously one bad line aborted the whole load).

    Args:
        cache_path: pathlib.Path to the JSONL cache file.

    Returns:
        dict mapping sample_id -> cached entry dict. Empty if the file
        does not exist or cannot be read.
    """
    cached_results = {}
    if cache_path.exists():
        try:
            with open(cache_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entry = json.loads(line)
                        cached_results[entry["sample_id"]] = entry
                    except (json.JSONDecodeError, KeyError) as e:
                        # Best-effort: drop only the bad line, keep the rest.
                        print(f"Warning: Skipping malformed cache line: {e}")
            print(f"Loaded {len(cached_results)} cached results from {cache_path}")
        except OSError as e:
            # File-level read failure (permissions, disk, etc.): warn and
            # return whatever was accumulated, matching best-effort intent.
            print(f"Warning: Error loading cache: {e}")
    return cached_results
53+
54+
55+
def save_to_cache(cache_path, sample_id, reference, prediction, audio_duration, transcription_time):
    """Append one transcription result to the JSONL cache file.

    Each call writes a single JSON object on its own line, so the file
    can be re-read incrementally and appended to across runs.
    """
    record = {
        "sample_id": sample_id,
        "reference": reference,
        "prediction": prediction,
        "audio_duration": audio_duration,
        "transcription_time": transcription_time,
    }
    with open(cache_path, "a") as handle:
        handle.write(json.dumps(record) + "\n")
66+
67+
68+
def get_sample_id(sample, index, use_url):
    """Generate a unique ID for a sample based on its content."""
    if use_url:
        # URL mode: the audio source URL plus position identifies the sample.
        raw = f"{index}_{sample['row']['audio'][0]['src']}"
    else:
        # Dataset mode: combine position, a text prefix, and the audio
        # sample count for better uniqueness.
        text = sample.get('norm_text', sample.get('text', ''))
        if 'audio' in sample:
            audio_len = len(sample.get('audio', {}).get('array', []))
        else:
            audio_len = index
        raw = f"{index}_{text[:50]}_{audio_len}"

    # md5 used purely as a stable fingerprint, not for security.
    return hashlib.md5(raw.encode()).hexdigest()
79+
80+
2981
def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
3082
API_URL = "https://datasets-server.huggingface.co/rows"
3183

@@ -256,7 +308,6 @@ def transcribe_with_retry(
256308
transcript_text.append(element.value)
257309

258310
return "".join(transcript_text) if transcript_text else ""
259-
260311
else:
261312
raise ValueError(
262313
"Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'"
@@ -289,17 +340,28 @@ def transcribe_dataset(
289340
use_url=False,
290341
max_samples=None,
291342
max_workers=4,
343+
clear_cache=False,
292344
):
345+
cache_path = get_cache_path(model_name, dataset_path, dataset, split)
346+
347+
if clear_cache and cache_path.exists():
348+
print(f"Clearing cache file: {cache_path}")
349+
cache_path.unlink()
350+
351+
cached_results = load_cache(cache_path)
352+
print(f"Cache file: {cache_path}")
353+
293354
if use_url:
294355
audio_rows = fetch_audio_urls(dataset_path, dataset, split)
295356
if max_samples:
296357
audio_rows = itertools.islice(audio_rows, max_samples)
297-
ds = audio_rows
358+
ds = list(enumerate(audio_rows))
298359
else:
299360
ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
300361
ds = data_utils.prepare_data(ds)
301362
if max_samples:
302363
ds = ds.take(max_samples)
364+
ds = list(enumerate(ds))
303365

304366
results = {
305367
"references": [],
@@ -310,7 +372,15 @@ def transcribe_dataset(
310372

311373
print(f"Transcribing with model: {model_name}")
312374

313-
def process_sample(sample):
375+
def process_sample(idx_sample):
376+
index, sample = idx_sample
377+
sample_id = get_sample_id(sample, index, use_url)
378+
379+
# Check if already cached
380+
if sample_id in cached_results:
381+
cached = cached_results[sample_id]
382+
return cached["reference"], cached["prediction"], cached["audio_duration"], cached["transcription_time"]
383+
314384
if use_url:
315385
reference = sample["row"]["text"].strip() or " "
316386
audio_duration = sample["row"]["audio_length_s"]
@@ -353,8 +423,19 @@ def process_sample(sample):
353423
print(f"File {tmp_path} does not exist")
354424

355425
transcription_time = time.time() - start
356-
return reference, transcription, audio_duration, transcription_time
357-
426+
427+
# Normalize before caching
428+
normalized_reference = data_utils.normalizer(reference) or " "
429+
normalized_prediction = data_utils.normalizer(transcription) or " "
430+
431+
# Save to cache
432+
save_to_cache(cache_path, sample_id, normalized_reference, normalized_prediction, audio_duration, transcription_time)
433+
434+
return normalized_reference, normalized_prediction, audio_duration, transcription_time
435+
436+
cached_count = sum(1 for idx, sample in ds if get_sample_id(sample, idx, use_url) in cached_results)
437+
print(f"Skipping {cached_count} cached samples, processing {len(ds) - cached_count} new samples")
438+
358439
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
359440
future_to_sample = {
360441
executor.submit(process_sample, sample): sample for sample in ds
@@ -371,15 +452,13 @@ def process_sample(sample):
371452
results["references"].append(reference)
372453
results["audio_length_s"].append(audio_duration)
373454
results["transcription_time_s"].append(transcription_time)
374-
375455
results["predictions"] = [
376456
data_utils.normalizer(transcription) or " "
377457
for transcription in results["predictions"]
378458
]
379459
results["references"] = [
380460
data_utils.normalizer(reference) or " " for reference in results["references"]
381461
]
382-
383462
manifest_path = data_utils.write_manifest(
384463
results["references"],
385464
results["predictions"],
@@ -420,13 +499,18 @@ def process_sample(sample):
420499
)
421500
parser.add_argument("--max_samples", type=int, default=None)
422501
parser.add_argument(
423-
"--max_workers", type=int, default=300, help="Number of concurrent threads"
502+
"--max_workers", type=int, default=32, help="Number of concurrent threads"
424503
)
425504
parser.add_argument(
426505
"--use_url",
427506
action="store_true",
428507
help="Use URL-based audio fetching instead of datasets",
429508
)
509+
parser.add_argument(
510+
"--clear_cache",
511+
action="store_true",
512+
help="Clear the cache for this model/dataset combination before starting",
513+
)
430514

431515
args = parser.parse_args()
432516

@@ -438,4 +522,5 @@ def process_sample(sample):
438522
use_url=args.use_url,
439523
max_samples=args.max_samples,
440524
max_workers=args.max_workers,
525+
clear_cache=args.clear_cache,
441526
)

0 commit comments

Comments
 (0)