Commit f95904b

Author: Jack (committed)
add cache for api
1 parent ce4cfb6 commit f95904b

File tree

1 file changed: +90 −8 lines changed


api/run_eval.py

Lines changed: 90 additions & 8 deletions
@@ -1,5 +1,8 @@
 import argparse
 from typing import Optional
+import json
+import hashlib
+from pathlib import Path
 import datasets
 import evaluate
 import soundfile as sf
@@ -26,6 +29,55 @@
 load_dotenv()
 
 
+def get_cache_path(model_name, dataset_path, dataset, split):
+    cache_dir = Path(".cache/transcriptions")
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    cache_key = f"{model_name}_{dataset_path}_{dataset}_{split}".replace("/", "_").replace(":", "_")
+    return cache_dir / f"{cache_key}.jsonl"
+
+
+def load_cache(cache_path):
+    cached_results = {}
+    if cache_path.exists():
+        try:
+            with open(cache_path, "r") as f:
+                for line in f:
+                    if line.strip():
+                        entry = json.loads(line)
+                        cached_results[entry["sample_id"]] = entry
+            print(f"Loaded {len(cached_results)} cached results from {cache_path}")
+        except Exception as e:
+            print(f"Warning: Error loading cache: {e}")
+    return cached_results
+
+
+def save_to_cache(cache_path, sample_id, reference, prediction, audio_duration, transcription_time):
+    entry = {
+        "sample_id": sample_id,
+        "reference": reference,
+        "prediction": prediction,
+        "audio_duration": audio_duration,
+        "transcription_time": transcription_time
+    }
+
+    with open(cache_path, "a") as f:
+        f.write(json.dumps(entry) + "\n")
+
+
+def get_sample_id(sample, index, use_url):
+    """Generate a unique ID for a sample based on its content."""
+    if use_url:
+        id_str = f"{index}_{sample['row']['audio'][0]['src']}"
+    else:
+        # Use the text content for better uniqueness
+        text = sample.get('norm_text', sample.get('text', ''))
+        audio_len = len(sample.get('audio', {}).get('array', [])) if 'audio' in sample else index
+        id_str = f"{index}_{text[:50]}_{audio_len}"
+
+    return hashlib.md5(id_str.encode()).hexdigest()
+
+
 def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
     API_URL = "https://datasets-server.huggingface.co/rows"

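For context, the three cache helpers above implement a plain append-only JSONL store: one JSON object per line, keyed by sample_id. Below is a minimal standalone sketch of that round-trip, with a hypothetical file name and entry values; it reimplements the read/write pattern rather than importing the module's own functions:

    import json
    import hashlib
    from pathlib import Path

    cache_path = Path(".cache/transcriptions/demo.jsonl")  # hypothetical file
    cache_path.parent.mkdir(parents=True, exist_ok=True)

    # Append one record per line, the same shape save_to_cache writes.
    entry = {
        "sample_id": hashlib.md5(b"0_hello world_16000").hexdigest(),
        "reference": "hello world",
        "prediction": "hello world",
        "audio_duration": 1.0,
        "transcription_time": 0.2,
    }
    with open(cache_path, "a") as f:
        f.write(json.dumps(entry) + "\n")

    # Re-read into a dict keyed by sample_id, the same shape load_cache returns.
    cached = {}
    with open(cache_path) as f:
        for line in f:
            if line.strip():
                record = json.loads(line)
                cached[record["sample_id"]] = record
    print(f"{len(cached)} cached result(s) loaded")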
@@ -256,7 +308,6 @@ def transcribe_with_retry(
                 transcript_text.append(element.value)
 
         return "".join(transcript_text) if transcript_text else ""
-
     else:
         raise ValueError(
             "Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'"
@@ -289,17 +340,28 @@ def transcribe_dataset(
     use_url=False,
     max_samples=None,
     max_workers=4,
+    clear_cache=False,
 ):
+    cache_path = get_cache_path(model_name, dataset_path, dataset, split)
+
+    if clear_cache and cache_path.exists():
+        print(f"Clearing cache file: {cache_path}")
+        cache_path.unlink()
+
+    cached_results = load_cache(cache_path)
+    print(f"Cache file: {cache_path}")
+
     if use_url:
         audio_rows = fetch_audio_urls(dataset_path, dataset, split)
         if max_samples:
             audio_rows = itertools.islice(audio_rows, max_samples)
-        ds = audio_rows
+        ds = list(enumerate(audio_rows))
     else:
         ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
         ds = data_utils.prepare_data(ds)
         if max_samples:
             ds = ds.take(max_samples)
+        ds = list(enumerate(ds))
 
     results = {
         "references": [],
@@ -310,7 +372,14 @@
 
     print(f"Transcribing with model: {model_name}")
 
-    def process_sample(sample):
+    def process_sample(idx_sample):
+        index, sample = idx_sample
+        sample_id = get_sample_id(sample, index, use_url)
+
+        if sample_id in cached_results:
+            cached = cached_results[sample_id]
+            return cached["reference"], cached["prediction"], cached["audio_duration"], cached["transcription_time"]
+
         if use_url:
             reference = sample["row"]["text"].strip() or " "
             audio_duration = sample["row"]["audio_length_s"]
@@ -353,8 +422,17 @@ def process_sample(sample):
                 print(f"File {tmp_path} does not exist")
 
         transcription_time = time.time() - start
-        return reference, transcription, audio_duration, transcription_time
-
+
+        normalized_reference = data_utils.normalizer(reference) or " "
+        normalized_prediction = data_utils.normalizer(transcription) or " "
+
+        save_to_cache(cache_path, sample_id, normalized_reference, normalized_prediction, audio_duration, transcription_time)
+
+        return normalized_reference, normalized_prediction, audio_duration, transcription_time
+
+    cached_count = sum(1 for idx, sample in ds if get_sample_id(sample, idx, use_url) in cached_results)
+    print(f"Skipping {cached_count} cached samples, processing {len(ds) - cached_count} new samples")
+
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         future_to_sample = {
             executor.submit(process_sample, sample): sample for sample in ds
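A note on cache keys: get_sample_id hashes the sample index together with sample content (the audio URL in URL mode, otherwise the first 50 characters of the text plus the audio array length), so a given sample maps to the same JSONL entry across runs. A quick illustration of the non-URL path, with made-up values:

    import hashlib

    index, text, audio_len = 0, "hello world", 16000  # hypothetical sample values
    id_str = f"{index}_{text[:50]}_{audio_len}"
    sample_id = hashlib.md5(id_str.encode()).hexdigest()
    print(sample_id)  # prints the 32-character hex digest used as the cache key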
@@ -371,15 +449,13 @@ def process_sample(sample):
             results["references"].append(reference)
             results["audio_length_s"].append(audio_duration)
             results["transcription_time_s"].append(transcription_time)
-
     results["predictions"] = [
         data_utils.normalizer(transcription) or " "
         for transcription in results["predictions"]
     ]
     results["references"] = [
         data_utils.normalizer(reference) or " " for reference in results["references"]
     ]
-
     manifest_path = data_utils.write_manifest(
         results["references"],
         results["predictions"],
@@ -420,13 +496,18 @@ def process_sample(sample):
     )
     parser.add_argument("--max_samples", type=int, default=None)
     parser.add_argument(
-        "--max_workers", type=int, default=300, help="Number of concurrent threads"
+        "--max_workers", type=int, default=32, help="Number of concurrent threads"
     )
     parser.add_argument(
         "--use_url",
         action="store_true",
         help="Use URL-based audio fetching instead of datasets",
     )
+    parser.add_argument(
+        "--clear_cache",
+        action="store_true",
+        help="Clear the cache for this model/dataset combination before starting",
+    )
 
     args = parser.parse_args()

@@ -438,4 +519,5 @@ def process_sample(sample):
         use_url=args.use_url,
         max_samples=args.max_samples,
         max_workers=args.max_workers,
+        clear_cache=args.clear_cache,
     )
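With the flag wired through, rerunning the same model/dataset/split resumes from the cache, while --clear_cache starts fresh. A hedged usage sketch: only flags visible in this diff are shown, and the script's remaining arguments (model and dataset selection) are elided:

    # Resume: samples already in .cache/transcriptions/ are skipped.
    python api/run_eval.py --max_workers 32 --max_samples 100 ...

    # Re-transcribe everything for this model/dataset combination.
    python api/run_eval.py --clear_cache ...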
