From e392f8bc55de18c019d764473ac6bf76e1a4b839 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:24:31 +0000 Subject: [PATCH 1/3] Create the `run_gemini.sh` Execution Script: I have created the script to run both English and multilingual benchmarks for the Gemini models. The script also includes the logic to score the results after each model's evaluation is complete. I have added `--max_samples 2` for testing. --- gemini/requirements_gemini.txt | 1 + gemini/run_eval.py | 277 +++++++++++++++++++++++++++++++++ gemini/run_eval_ml.py | 255 ++++++++++++++++++++++++++++++ gemini/run_gemini.sh | 102 ++++++++++++ 4 files changed, 635 insertions(+) create mode 100644 gemini/requirements_gemini.txt create mode 100644 gemini/run_eval.py create mode 100644 gemini/run_eval_ml.py create mode 100644 gemini/run_gemini.sh diff --git a/gemini/requirements_gemini.txt b/gemini/requirements_gemini.txt new file mode 100644 index 0000000..27f8b2f --- /dev/null +++ b/gemini/requirements_gemini.txt @@ -0,0 +1 @@ +google-generativeai diff --git a/gemini/run_eval.py b/gemini/run_eval.py new file mode 100644 index 0000000..4417d55 --- /dev/null +++ b/gemini/run_eval.py @@ -0,0 +1,277 @@ +import argparse +from typing import Optional +import datasets +import evaluate +import soundfile as sf +import tempfile +import time +import os +import requests +import itertools +from tqdm import tqdm +from io import BytesIO +from normalizer import data_utils +import concurrent.futures +import getpass +import google.generativeai as genai + + +def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20): + API_URL = "https://datasets-server.huggingface.co/rows" + + size_url = f"https://datasets-server.huggingface.co/size?dataset={dataset_path}&config={dataset}&split={split}" + size_response = requests.get(size_url).json() + total_rows = size_response["size"]["config"]["num_rows"] + audio_urls = [] + for offset in tqdm(range(0, total_rows, batch_size), desc="Fetching audio URLs"): + params = { + "dataset": dataset_path, + "config": dataset, + "split": split, + "offset": offset, + "length": min(batch_size, total_rows - offset), + } + + retries = 0 + while retries <= max_retries: + try: + headers = {} + if os.environ.get("HF_TOKEN") is not None: + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + else: + print("HF_TOKEN not set, might experience rate-limiting.") + response = requests.get(API_URL, params=params) + response.raise_for_status() + data = response.json() + yield from data["rows"] + break + except (requests.exceptions.RequestException, ValueError) as e: + retries += 1 + print( + f"Error fetching data: {e}, retrying ({retries}/{max_retries})..." 
+ ) + time.sleep(10) + if retries >= max_retries: + raise Exception("Max retries exceeded while fetching data.") + + +def transcribe_with_retry( + model_name: str, + audio_file_path: Optional[str], + sample: dict, + max_retries=10, + use_url=False, +): + retries = 0 + while retries <= max_retries: + try: + if model_name.startswith("gemini/"): + model_id = model_name.split("/", 1)[1] + model = genai.GenerativeModel(model_id) + + if use_url: + # Download the audio file from the URL to a temporary file + response = requests.get(sample["row"]["audio"][0]["src"]) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: + tmpfile.write(response.content) + audio_file_path = tmpfile.name + + # Upload the file to the Gemini API + gemini_file = genai.upload_file(path=audio_file_path) + + # Transcribe the audio + response = model.generate_content(["Generate a transcript of the speech.", gemini_file]) + + # Clean up the uploaded file + genai.delete_file(gemini_file.name) + + if use_url: + # Clean up the temporary file + os.unlink(audio_file_path) + + return response.text.strip() + else: + raise ValueError( + "Invalid model prefix, must start with 'gemini/'" + ) + + except Exception as e: + retries += 1 + if retries > max_retries: + raise e + + if not use_url: + sf.write( + audio_file_path, + sample["audio"]["array"], + sample["audio"]["sampling_rate"], + format="WAV", + ) + delay = 1 + print( + f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})" + ) + time.sleep(delay) + + +def transcribe_dataset( + dataset_path, + dataset, + split, + model_name, + use_url=False, + max_samples=None, + max_workers=4, +): + if use_url: + audio_rows = fetch_audio_urls(dataset_path, dataset, split) + if max_samples: + audio_rows = itertools.islice(audio_rows, max_samples) + ds = audio_rows + else: + ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False) + ds = data_utils.prepare_data(ds) + if max_samples: + ds = ds.take(max_samples) + + results = { + "references": [], + "predictions": [], + "audio_length_s": [], + "transcription_time_s": [], + } + + print(f"Transcribing with model: {model_name}") + + def process_sample(sample): + if use_url: + reference = sample["row"]["text"].strip() or " " + audio_duration = sample["row"]["audio_length_s"] + start = time.time() + try: + transcription = transcribe_with_retry( + model_name, None, sample, use_url=True + ) + except Exception as e: + print(f"Failed to transcribe after retries: {e}") + return None + + else: + reference = sample.get("norm_text", "").strip() or " " + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: + sf.write( + tmpfile.name, + sample["audio"]["array"], + sample["audio"]["sampling_rate"], + format="WAV", + ) + tmp_path = tmpfile.name + audio_duration = ( + len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"] + ) + + start = time.time() + try: + transcription = transcribe_with_retry( + model_name, tmp_path, sample, use_url=False + ) + except Exception as e: + print(f"Failed to transcribe after retries: {e}") + os.unlink(tmp_path) + return None + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + else: + print(f"File {tmp_path} does not exist") + + transcription_time = time.time() - start + return reference, transcription, audio_duration, transcription_time + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_sample = { + executor.submit(process_sample, sample): sample for sample in ds + } + for 
future in tqdm( + concurrent.futures.as_completed(future_to_sample), + total=len(future_to_sample), + desc="Transcribing", + ): + result = future.result() + if result: + reference, transcription, audio_duration, transcription_time = result + results["predictions"].append(transcription) + results["references"].append(reference) + results["audio_length_s"].append(audio_duration) + results["transcription_time_s"].append(transcription_time) + + results["predictions"] = [ + data_utils.normalizer(transcription) or " " + for transcription in results["predictions"] + ] + results["references"] = [ + data_utils.normalizer(reference) or " " for reference in results["references"] + ] + + manifest_path = data_utils.write_manifest( + results["references"], + results["predictions"], + model_name.replace("/", "-"), + dataset_path, + dataset, + split, + audio_length=results["audio_length_s"], + transcription_time=results["transcription_time_s"], + ) + + print("Results saved at path:", manifest_path) + + wer_metric = evaluate.load("wer") + wer = wer_metric.compute( + references=results["references"], predictions=results["predictions"] + ) + wer_percent = round(100 * wer, 2) + rtfx = round( + sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2 + ) + + print("WER:", wer_percent, "%") + print("RTFx:", rtfx) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Unified Transcription Script with Concurrency" + ) + parser.add_argument("--dataset_path", required=True) + parser.add_argument("--dataset", required=True) + parser.add_argument("--split", default="test") + parser.add_argument( + "--model_name", + required=True, + help="Prefix model name with 'gemini/'", + ) + parser.add_argument("--max_samples", type=int, default=None) + parser.add_argument( + "--max_workers", type=int, default=300, help="Number of concurrent threads" + ) + parser.add_argument( + "--use_url", + action="store_true", + help="Use URL-based audio fetching instead of datasets", + ) + + args = parser.parse_args() + + gemini_api_key = getpass.getpass("Enter your Gemini API key: ") + genai.configure(api_key=gemini_api_key) + + transcribe_dataset( + dataset_path=args.dataset_path, + dataset=args.dataset, + split=args.split, + model_name=args.model_name, + use_url=args.use_url, + max_samples=args.max_samples, + max_workers=args.max_workers, + ) diff --git a/gemini/run_eval_ml.py b/gemini/run_eval_ml.py new file mode 100644 index 0000000..606227f --- /dev/null +++ b/gemini/run_eval_ml.py @@ -0,0 +1,255 @@ +import argparse +from typing import Optional +import datasets +import evaluate +import soundfile as sf +import tempfile +import time +import os +import requests +from tqdm import tqdm +from normalizer import data_utils +import concurrent.futures +import getpass +import google.generativeai as genai + + +def transcribe_with_retry( + model_name: str, + audio_file_path: Optional[str], + sample: dict, + max_retries=10, + use_url=False, # This is not used in the ml script, but we keep it for consistency with the function signature +): + retries = 0 + while retries <= max_retries: + try: + if model_name.startswith("gemini/"): + model_id = model_name.split("/", 1)[1] + model = genai.GenerativeModel(model_id) + + # In the multilingual script, we always have a local file + # so we don't need to handle the use_url case for downloading. 
+ + # Upload the file to the Gemini API + gemini_file = genai.upload_file(path=audio_file_path) + + # Transcribe the audio + response = model.generate_content(["Generate a transcript of the speech.", gemini_file]) + + # Clean up the uploaded file + genai.delete_file(gemini_file.name) + + return response.text.strip() + else: + raise ValueError( + "Invalid model prefix, must start with 'gemini/'" + ) + + except Exception as e: + retries += 1 + if retries > max_retries: + raise e + + # Re-writing the file on failure is handled by the caller. + delay = 1 + print( + f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})" + ) + time.sleep(delay) + + +def transcribe_dataset( + dataset, + config_name, + language, + split, + model_name, + max_samples, + max_workers, + streaming, +): + + print(f"Loading dataset: {dataset} with config: {config_name}") + ds = datasets.load_dataset(dataset, config_name, split=split, streaming=streaming) + + if max_samples is not None and max_samples > 0: + print(f"Subsampling dataset to first {max_samples} samples!") + if streaming: + ds = ds.take(max_samples) + else: + ds = ds.select(range(min(max_samples, len(ds)))) + + + results = { + "references": [], + "predictions": [], + "audio_length_s": [], + "transcription_time_s": [], + } + + print(f"Transcribing with model: {model_name}") + + def process_sample(sample): + reference = sample.get("text", "").strip() or " " + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: + sf.write( + tmpfile.name, + sample["audio"]["array"], + sample["audio"]["sampling_rate"], + format="WAV", + ) + tmp_path = tmpfile.name + audio_duration = ( + len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"] + ) + + start = time.time() + try: + transcription = transcribe_with_retry( + model_name, tmp_path, sample + ) + except Exception as e: + print(f"Failed to transcribe after retries: {e}") + os.unlink(tmp_path) + return None + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + else: + print(f"File {tmp_path} does not exist") + + transcription_time = time.time() - start + return reference, transcription, audio_duration, transcription_time + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_sample = { + executor.submit(process_sample, sample): sample for sample in ds + } + for future in tqdm( + concurrent.futures.as_completed(future_to_sample), + total=len(future_to_sample), + desc="Transcribing", + ): + result = future.result() + if result: + reference, transcription, audio_duration, transcription_time = result + results["predictions"].append(transcription) + results["references"].append(reference) + results["audio_length_s"].append(audio_duration) + results["transcription_time_s"].append(transcription_time) + + if language == "en": + results["predictions"] = [ + data_utils.normalizer(transcription) or " " + for transcription in results["predictions"] + ] + results["references"] = [ + data_utils.normalizer(reference) or " " for reference in results["references"] + ] + else: + results["predictions"] = [ + data_utils.ml_normalizer(transcription) or " " + for transcription in results["predictions"] + ] + results["references"] = [ + data_utils.ml_normalizer(reference) or " " for reference in results["references"] + ] + + + manifest_path = data_utils.write_manifest( + results["references"], + results["predictions"], + model_name.replace("/", "-"), + dataset, + config_name, + split, + audio_length=results["audio_length_s"], + 
transcription_time=results["transcription_time_s"], + ) + + print("Results saved at path:", manifest_path) + + wer_metric = evaluate.load("wer") + wer = wer_metric.compute( + references=results["references"], predictions=results["predictions"] + ) + wer_percent = round(100 * wer, 2) + rtfx = round( + sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2 + ) + + print("WER:", wer_percent, "%") + print("RTFx:", rtfx) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name", type=str, required=True, help="Model identifier. Should be loadable with Gemini.", + ) + parser.add_argument( + "--dataset", + type=str, + default="nithinraok/asr-leaderboard-datasets", + help="Dataset name. Default is 'nithinraok/asr-leaderboard-datasets'" + ) + parser.add_argument( + "--config_name", + type=str, + required=True, + help="Config name in format _ (e.g., fleurs_en, mcv_de, mls_es)" + ) + parser.add_argument( + "--language", + type=str, + default=None, + help="Language code (e.g., en, de, es). If not provided, will be extracted from config_name." + ) + parser.add_argument( + "--split", + type=str, + default="test", + help="Split of the dataset. Default is 'test'.", + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.", + ) + parser.add_argument( + "--max_workers", type=int, default=4, help="Number of concurrent threads" + ) + parser.add_argument( + "--no-streaming", + dest='streaming', + action="store_false", + help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.", + ) + args = parser.parse_args() + parser.set_defaults(streaming=True) + + if args.language is None: + try: + args.language = args.config_name.split('_', 1)[1] + except IndexError: + raise ValueError("Language could not be inferred from config_name. 
Please specify it with --language.") + + print(f"Detected language: {args.language}") + + gemini_api_key = getpass.getpass("Enter your Gemini API key: ") + genai.configure(api_key=gemini_api_key) + + transcribe_dataset( + dataset=args.dataset, + config_name=args.config_name, + language=args.language, + split=args.split, + model_name=args.model_name, + max_samples=args.max_samples, + max_workers=args.max_workers, + streaming=args.streaming, + ) diff --git a/gemini/run_gemini.sh b/gemini/run_gemini.sh new file mode 100644 index 0000000..8e085e7 --- /dev/null +++ b/gemini/run_gemini.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +export PYTHONPATH="..":$PYTHONPATH + +MODEL_IDs=( + "gemini/gemini-2.5-pro" + "gemini/gemini-2.5-flash" +) + +for MODEL_ID in "${MODEL_IDs[@]}" +do + echo "--- Running Benchmarks for $MODEL_ID ---" + + # --- English Benchmarks --- + echo "--- Running English Benchmarks ---" + python gemini/run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="ami" \ + --split="test" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="earnings22" \ + --split="test" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="gigaspeech" \ + --split="test" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path "hf-audio/esb-datasets-test-only-sorted" \ + --dataset "librispeech" \ + --split "test.clean" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path "hf-audio/esb-datasets-test-only-sorted" \ + --dataset "librispeech" \ + --split "test.other" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="spgispeech" \ + --split="test" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="tedlium" \ + --split="test" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + python gemini/run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="voxpopuli" \ + --split="test" \ + --model_name ${MODEL_ID} \ + --max_samples 2 + + # --- Multilingual Benchmarks --- + echo "--- Running Multilingual Benchmarks ---" + declare -A EVAL_DATASETS + EVAL_DATASETS["fleurs"]="en de fr it es pt" + EVAL_DATASETS["mcv"]="en de es fr it" + EVAL_DATASETS["mls"]="es fr it pt" + + for dataset in ${!EVAL_DATASETS[@]}; do + for language in ${EVAL_DATASETS[$dataset]}; do + config_name="${dataset}_${language}" + echo "Running evaluation for $config_name" + python gemini/run_eval_ml.py \ + --model_name="$MODEL_ID" \ + --dataset="nithinraok/asr-leaderboard-datasets" \ + --config_name="$config_name" \ + --language="$language" \ + --split="test" \ + --max_samples 2 + done + done + + # --- Scoring --- + echo "--- Scoring results for $MODEL_ID ---" + RUNDIR=$(pwd) + cd normalizer && \ + python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \ + cd $RUNDIR + +done + +echo "--- All benchmarks complete ---" From bc7b02042abbef4a0c9f44c3d178d7b61ff34cbb Mon Sep 17 00:00:00 2001 From: Alexander Immler Date: Fri, 22 Aug 2025 12:03:20 +0200 Subject: [PATCH 2/3] gemini: align eval with template/canary; standardize CLI flags; avoid torchcodec decode; load .env in run script 
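
The torchcodec workaround boils down to casting the audio column so `datasets` hands back raw bytes/paths instead of decoding, then reading with `soundfile`. Roughly, as a sketch (the real implementation is in the run_eval.py diff below):

    import io

    import soundfile as sf
    from datasets import Audio, load_dataset

    ds = load_dataset("hf-audio/esb-datasets-test-only-sorted", "ami", split="test")
    # Keep raw bytes/paths; do not let `datasets` decode the audio column.
    ds = ds.cast_column("audio", Audio(decode=False))

    info = ds[0]["audio"]
    if info.get("bytes") is not None:
        array, sr = sf.read(io.BytesIO(info["bytes"]), dtype="float32")
    else:
        array, sr = sf.read(info["path"], dtype="float32")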
--- gemini/run_eval.py | 397 ++++++++++++++++++------------------------ gemini/run_eval_ml.py | 248 ++++++++++++++++---------- gemini/run_gemini.sh | 145 ++++++++------- 3 files changed, 415 insertions(+), 375 deletions(-) diff --git a/gemini/run_eval.py b/gemini/run_eval.py index 4417d55..9419988 100644 --- a/gemini/run_eval.py +++ b/gemini/run_eval.py @@ -1,277 +1,228 @@ import argparse -from typing import Optional -import datasets -import evaluate -import soundfile as sf -import tempfile -import time +import io import os -import requests -import itertools -from tqdm import tqdm -from io import BytesIO -from normalizer import data_utils -import concurrent.futures +import time import getpass -import google.generativeai as genai - -def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20): - API_URL = "https://datasets-server.huggingface.co/rows" +import evaluate +import datasets +import numpy as np +import soundfile as sf +from tqdm import tqdm - size_url = f"https://datasets-server.huggingface.co/size?dataset={dataset_path}&config={dataset}&split={split}" - size_response = requests.get(size_url).json() - total_rows = size_response["size"]["config"]["num_rows"] - audio_urls = [] - for offset in tqdm(range(0, total_rows, batch_size), desc="Fetching audio URLs"): - params = { - "dataset": dataset_path, - "config": dataset, - "split": split, - "offset": offset, - "length": min(batch_size, total_rows - offset), - } +from normalizer import data_utils - retries = 0 - while retries <= max_retries: - try: - headers = {} - if os.environ.get("HF_TOKEN") is not None: - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" - else: - print("HF_TOKEN not set, might experience rate-limiting.") - response = requests.get(API_URL, params=params) - response.raise_for_status() - data = response.json() - yield from data["rows"] - break - except (requests.exceptions.RequestException, ValueError) as e: - retries += 1 - print( - f"Error fetching data: {e}, retrying ({retries}/{max_retries})..." - ) - time.sleep(10) - if retries >= max_retries: - raise Exception("Max retries exceeded while fetching data.") +try: + import google.generativeai as genai +except ImportError: + print("Error: google-generativeai not installed. 
Run: pip install google-generativeai") + exit(1) def transcribe_with_retry( - model_name: str, - audio_file_path: Optional[str], - sample: dict, - max_retries=10, - use_url=False, -): + model_id: str, + audio_file_path: str, + max_retries: int = 10, +) -> str: retries = 0 while retries <= max_retries: try: - if model_name.startswith("gemini/"): - model_id = model_name.split("/", 1)[1] - model = genai.GenerativeModel(model_id) - - if use_url: - # Download the audio file from the URL to a temporary file - response = requests.get(sample["row"]["audio"][0]["src"]) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: - tmpfile.write(response.content) - audio_file_path = tmpfile.name + if model_id.startswith("gemini/"): + _model = model_id.split("/", 1)[1] + model = genai.GenerativeModel(_model) - # Upload the file to the Gemini API gemini_file = genai.upload_file(path=audio_file_path) - - # Transcribe the audio - response = model.generate_content(["Generate a transcript of the speech.", gemini_file]) - - # Clean up the uploaded file + response = model.generate_content([ + "Generate a transcript of the speech.", + gemini_file, + ]) genai.delete_file(gemini_file.name) - if use_url: - # Clean up the temporary file - os.unlink(audio_file_path) - - return response.text.strip() + return response.text.strip() if getattr(response, "text", None) else "" else: - raise ValueError( - "Invalid model prefix, must start with 'gemini/'" - ) + raise ValueError("Invalid model prefix, must start with 'gemini/'") except Exception as e: retries += 1 if retries > max_retries: raise e - if not use_url: - sf.write( - audio_file_path, - sample["audio"]["array"], - sample["audio"]["sampling_rate"], - format="WAV", - ) - delay = 1 + delay = min(2 ** retries, 30) # Exponential backoff with max 30s print( f"API Error: {str(e)}. Retrying in {delay}s... 
(Attempt {retries}/{max_retries})" ) time.sleep(delay) + + # This should never be reached, but adding for type safety + return "" -def transcribe_dataset( - dataset_path, - dataset, - split, - model_name, - use_url=False, - max_samples=None, - max_workers=4, -): - if use_url: - audio_rows = fetch_audio_urls(dataset_path, dataset, split) - if max_samples: - audio_rows = itertools.islice(audio_rows, max_samples) - ds = audio_rows - else: - ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False) - ds = data_utils.prepare_data(ds) - if max_samples: - ds = ds.take(max_samples) +def main(args): + DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache") + CACHE_DIR = os.path.join(DATA_CACHE_DIR, args.dataset, args.split) + os.makedirs(CACHE_DIR, exist_ok=True) - results = { - "references": [], - "predictions": [], - "audio_length_s": [], - "transcription_time_s": [], - } - - print(f"Transcribing with model: {model_name}") - - def process_sample(sample): - if use_url: - reference = sample["row"]["text"].strip() or " " - audio_duration = sample["row"]["audio_length_s"] - start = time.time() - try: - transcription = transcribe_with_retry( - model_name, None, sample, use_url=True - ) - except Exception as e: - print(f"Failed to transcribe after retries: {e}") - return None - - else: - reference = sample.get("norm_text", "").strip() or " " - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: - sf.write( - tmpfile.name, - sample["audio"]["array"], - sample["audio"]["sampling_rate"], - format="WAV", - ) - tmp_path = tmpfile.name - audio_duration = ( - len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"] - ) + # Load dataset without triggering audio decoding (avoid torchcodec) + ds = datasets.load_dataset( + args.dataset_path, + args.dataset, + split=args.split, + streaming=False, + token=True, + ) + # Keep audio as filepaths to avoid decoding here + try: + from datasets import Audio + ds = ds.cast_column("audio", Audio(decode=False)) + except Exception: + pass + + # Subsample + if args.max_eval_samples is not None and args.max_eval_samples > 0: + print(f"Subsampling dataset to first {args.max_eval_samples} samples!") + if hasattr(ds, "select") and hasattr(ds, "__len__"): + ds = ds.select(range(min(args.max_eval_samples, len(ds)))) + + references = [] + audio_paths = [] + durations = [] + + for sample in tqdm(ds, desc="Preparing samples"): + sid = str(sample.get("id", "sample")).replace("/", "_").removesuffix(".wav") + audio_info = sample.get("audio") + if not isinstance(audio_info, dict): + print("Skipping sample without audio info") + continue + try: + if audio_info.get("bytes") is not None: + with io.BytesIO(audio_info["bytes"]) as bio: + audio_array, sr = sf.read(bio, dtype="float32") + elif audio_info.get("path"): + audio_array, sr = sf.read(audio_info["path"], dtype="float32") + elif audio_info.get("array") is not None: + audio_array = np.float32(audio_info["array"]) if not isinstance(audio_info["array"], np.ndarray) else audio_info["array"].astype(np.float32) + sr = audio_info.get("sampling_rate", 16000) + else: + print("Skipping sample: unsupported audio format") + continue + except Exception as e: + print(f"Failed to read audio: {e}") + continue - start = time.time() - try: - transcription = transcribe_with_retry( - model_name, tmp_path, sample, use_url=False - ) - except Exception as e: - print(f"Failed to transcribe after retries: {e}") - os.unlink(tmp_path) - return None - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) 
- else: - print(f"File {tmp_path} does not exist") + out_path = os.path.join(CACHE_DIR, f"{sid}.wav") + if not os.path.exists(out_path): + os.makedirs(os.path.dirname(out_path), exist_ok=True) + sf.write(out_path, audio_array, sr) - transcription_time = time.time() - start - return reference, transcription, audio_duration, transcription_time + audio_paths.append(out_path) + durations.append(len(audio_array) / sr) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_sample = { - executor.submit(process_sample, sample): sample for sample in ds - } - for future in tqdm( - concurrent.futures.as_completed(future_to_sample), - total=len(future_to_sample), - desc="Transcribing", - ): - result = future.result() - if result: - reference, transcription, audio_duration, transcription_time = result - results["predictions"].append(transcription) - results["references"].append(reference) - results["audio_length_s"].append(audio_duration) - results["transcription_time_s"].append(transcription_time) + # Normalize reference text + try: + ref_text = data_utils.get_text(sample) + except Exception: + ref_text = sample.get("text", " ") + references.append(data_utils.normalizer(ref_text) or " ") + + if args.max_eval_samples is not None and len(audio_paths) >= args.max_eval_samples: + break + + # Transcribe + predictions = [] + transcription_times = [] + print(f"Transcribing with model: {args.model_id}") + for audio_path in tqdm(audio_paths, desc="Transcribing"): + start = time.time() + try: + pred_text = transcribe_with_retry(args.model_id, audio_path) + except Exception as e: + print(f"Failed to transcribe {audio_path}: {e}") + pred_text = " " + elapsed = time.time() - start + transcription_times.append(elapsed) + predictions.append(data_utils.normalizer(pred_text) or " ") + time.sleep(0.1) - results["predictions"] = [ - data_utils.normalizer(transcription) or " " - for transcription in results["predictions"] - ] - results["references"] = [ - data_utils.normalizer(reference) or " " for reference in results["references"] - ] + if len(predictions) == 0: + print("No samples were successfully processed.") + return manifest_path = data_utils.write_manifest( - results["references"], - results["predictions"], - model_name.replace("/", "-"), - dataset_path, - dataset, - split, - audio_length=results["audio_length_s"], - transcription_time=results["transcription_time_s"], + references, + predictions, + args.model_id, + args.dataset_path, + args.dataset, + args.split, + audio_length=durations, + transcription_time=transcription_times, ) - - print("Results saved at path:", manifest_path) + print("Results saved at path:", os.path.abspath(manifest_path)) wer_metric = evaluate.load("wer") - wer = wer_metric.compute( - references=results["references"], predictions=results["predictions"] - ) - wer_percent = round(100 * wer, 2) - rtfx = round( - sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2 - ) - - print("WER:", wer_percent, "%") + wer = wer_metric.compute(references=references, predictions=predictions) + wer = round(100 * wer, 2) + rtfx = round(sum(durations) / max(1e-9, sum(transcription_times)), 2) + print("WER:", wer, "%") print("RTFx:", rtfx) if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Unified Transcription Script with Concurrency" + parser = argparse.ArgumentParser(description="Gemini ASR Evaluation Script") + parser.add_argument( + "--model_id", + type=str, + required=True, + help="Model identifier, must start with 
'gemini/'", + ) + parser.add_argument( + "--dataset_path", + type=str, + default="esb/datasets", + help="Dataset path. By default, it is `esb/datasets`", ) - parser.add_argument("--dataset_path", required=True) - parser.add_argument("--dataset", required=True) - parser.add_argument("--split", default="test") parser.add_argument( - "--model_name", + "--dataset", + type=str, required=True, - help="Prefix model name with 'gemini/'", + help="Dataset name.", ) - parser.add_argument("--max_samples", type=int, default=None) parser.add_argument( - "--max_workers", type=int, default=300, help="Number of concurrent threads" + "--split", + type=str, + default="test", + help="Split of the dataset.", ) parser.add_argument( - "--use_url", - action="store_true", - help="Use URL-based audio fetching instead of datasets", + "--batch_size", + type=int, + default=8, + help="Number of samples per streamed batch.", ) + parser.add_argument( + "--max_eval_samples", + type=int, + default=None, + help="Number of samples to be evaluated.", + ) + parser.add_argument( + "--no-streaming", + dest="streaming", + action="store_false", + help="Download the entire dataset instead of streaming.", + ) + parser.set_defaults(streaming=True) args = parser.parse_args() - gemini_api_key = getpass.getpass("Enter your Gemini API key: ") - genai.configure(api_key=gemini_api_key) - - transcribe_dataset( - dataset_path=args.dataset_path, - dataset=args.dataset, - split=args.split, - model_name=args.model_name, - use_url=args.use_url, - max_samples=args.max_samples, - max_workers=args.max_workers, - ) + api_key = os.getenv("GOOGLE_API_KEY") + if not api_key: + try: + api_key = getpass.getpass("Enter your Gemini API key: ") + except Exception: + api_key = None + if not api_key: + raise RuntimeError("GOOGLE_API_KEY not set and no key provided interactively.") + genai.configure(api_key=api_key) + + main(args) diff --git a/gemini/run_eval_ml.py b/gemini/run_eval_ml.py index 606227f..bc14b47 100644 --- a/gemini/run_eval_ml.py +++ b/gemini/run_eval_ml.py @@ -1,32 +1,38 @@ import argparse from typing import Optional import datasets +from datasets import Dataset import evaluate import soundfile as sf -import tempfile import time +import tempfile import os -import requests +import io +import numpy as np +import pandas as pd from tqdm import tqdm from normalizer import data_utils -import concurrent.futures import getpass -import google.generativeai as genai +try: + import google.generativeai as genai +except ImportError: + print("Error: google-generativeai not installed. Run: pip install google-generativeai") + exit(1) def transcribe_with_retry( - model_name: str, - audio_file_path: Optional[str], - sample: dict, + model_id: str, + audio_file_path: str, + sample: Optional[dict] = None, max_retries=10, use_url=False, # This is not used in the ml script, but we keep it for consistency with the function signature -): +) -> str: retries = 0 while retries <= max_retries: try: - if model_name.startswith("gemini/"): - model_id = model_name.split("/", 1)[1] - model = genai.GenerativeModel(model_id) + if model_id.startswith("gemini/"): + _model = model_id.split("/", 1)[1] + model = genai.GenerativeModel(_model) # In the multilingual script, we always have a local file # so we don't need to handle the use_url case for downloading. 
@@ -40,7 +46,7 @@ def transcribe_with_retry( # Clean up the uploaded file genai.delete_file(gemini_file.name) - return response.text.strip() + return response.text.strip() if response.text else "" else: raise ValueError( "Invalid model prefix, must start with 'gemini/'" @@ -52,11 +58,14 @@ def transcribe_with_retry( raise e # Re-writing the file on failure is handled by the caller. - delay = 1 + delay = min(2 ** retries, 30) # Exponential backoff with max 30s print( f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})" ) time.sleep(delay) + + # This should never be reached, but adding for type safety + return "" def transcribe_dataset( @@ -64,22 +73,44 @@ def transcribe_dataset( config_name, language, split, - model_name, - max_samples, + model_id, + max_eval_samples, max_workers, streaming, ): print(f"Loading dataset: {dataset} with config: {config_name}") - ds = datasets.load_dataset(dataset, config_name, split=split, streaming=streaming) - - if max_samples is not None and max_samples > 0: - print(f"Subsampling dataset to first {max_samples} samples!") - if streaming: - ds = ds.take(max_samples) + # Force non-streaming to avoid torchcodec issues + ds = datasets.load_dataset(dataset, config_name, split=split, streaming=False) + + # Apply subsampling first, then convert to pandas + if max_eval_samples is not None and max_eval_samples > 0: + print(f"Subsampling dataset to first {max_eval_samples} samples...") + if hasattr(ds, 'select') and hasattr(ds, '__len__'): + ds = ds.select(range(min(max_eval_samples, len(ds)))) else: - ds = ds.select(range(min(max_samples, len(ds)))) - + # Fallback: convert to list and slice + dataset_list = list(ds) + dataset_list = dataset_list[:max_eval_samples] + ds = Dataset.from_list(dataset_list) + + # Convert to pandas DataFrame to avoid audio decoding during iteration + print("Converting to pandas DataFrame...") + try: + df = ds.to_pandas() + print(f"Successfully loaded {len(df)} samples") + print(f"Dataset columns: {df.columns.tolist()}") + except Exception as e: + print(f"Error converting to pandas: {e}") + # Fallback to list conversion + dataset_list = [] + for i, sample in enumerate(ds): + if max_eval_samples and i >= max_eval_samples: + break + dataset_list.append(sample) + df = pd.DataFrame(dataset_list) + print(f"Successfully loaded {len(df)} samples via fallback") + print(f"Dataset columns: {df.columns.tolist()}") results = { "references": [], @@ -88,57 +119,83 @@ def transcribe_dataset( "transcription_time_s": [], } - print(f"Transcribing with model: {model_name}") - - def process_sample(sample): - reference = sample.get("text", "").strip() or " " - - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: - sf.write( - tmpfile.name, - sample["audio"]["array"], - sample["audio"]["sampling_rate"], - format="WAV", - ) - tmp_path = tmpfile.name - audio_duration = ( - len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"] - ) - - start = time.time() + print(f"Transcribing with model: {model_id}") + print(f"Processing samples...") + + # Process samples sequentially using pandas approach + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Transcribing"): try: - transcription = transcribe_with_retry( - model_name, tmp_path, sample - ) - except Exception as e: - print(f"Failed to transcribe after retries: {e}") - os.unlink(tmp_path) - return None - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) + reference = str(row.get("text", "")).strip() or " " + + # Handle audio data similar 
to main script + audio_data = row.get('audio') + + if audio_data is None: + print(f"Skipping sample {idx} - no audio data") + continue + + # Handle different audio formats from pandas + if isinstance(audio_data, dict): + if 'bytes' in audio_data and audio_data['bytes'] is not None: + # Handle bytes format first (most reliable) + try: + with io.BytesIO(audio_data['bytes']) as audio_file: + audio_array, sample_rate = sf.read(audio_file, dtype="float32") + except Exception as e: + print(f"Error loading audio from bytes: {e}") + continue + elif 'array' in audio_data and audio_data['array'] is not None: + audio_array = np.array(audio_data['array'], dtype=np.float32) + sample_rate = audio_data.get('sampling_rate', 16000) + elif 'path' in audio_data and audio_data['path'] is not None: + # Load from file path (last resort) + try: + audio_array, sample_rate = sf.read(audio_data['path'], dtype="float32") + except Exception as e: + print(f"Error loading audio from path {audio_data['path']}: {e}") + continue + else: + print(f"Skipping sample {idx} - unsupported audio format. Available keys: {list(audio_data.keys())}") + continue else: - print(f"File {tmp_path} does not exist") - - transcription_time = time.time() - start - return reference, transcription, audio_duration, transcription_time - - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_sample = { - executor.submit(process_sample, sample): sample for sample in ds - } - for future in tqdm( - concurrent.futures.as_completed(future_to_sample), - total=len(future_to_sample), - desc="Transcribing", - ): - result = future.result() - if result: - reference, transcription, audio_duration, transcription_time = result - results["predictions"].append(transcription) - results["references"].append(reference) - results["audio_length_s"].append(audio_duration) - results["transcription_time_s"].append(transcription_time) + print(f"Skipping sample {idx} - unexpected audio format: {type(audio_data)}") + continue + + try: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: + sf.write( + tmpfile.name, + audio_array, + sample_rate, + format="WAV", + ) + tmp_path = tmpfile.name + audio_duration = len(audio_array) / sample_rate + + start = time.time() + try: + transcription = transcribe_with_retry(model_id, tmp_path, row) + if transcription: + results["predictions"].append(transcription) + results["references"].append(reference) + results["audio_length_s"].append(audio_duration) + transcription_time = time.time() - start + results["transcription_time_s"].append(transcription_time) + print(f"Sample {idx+1}: Transcribed successfully") + except Exception as e: + print(f"Failed to transcribe sample {idx} after retries: {e}") + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + # Add small delay to avoid rate limiting + time.sleep(0.1) + except Exception as e: + print(f"Error processing sample {idx}: {e}") + continue + except Exception as e: + print(f"Error processing sample {idx}: {e}") + continue if language == "en": results["predictions"] = [ @@ -161,7 +218,7 @@ def process_sample(sample): manifest_path = data_utils.write_manifest( results["references"], results["predictions"], - model_name.replace("/", "-"), + model_id, dataset, config_name, split, @@ -171,24 +228,33 @@ def process_sample(sample): print("Results saved at path:", manifest_path) - wer_metric = evaluate.load("wer") - wer = wer_metric.compute( - references=results["references"], predictions=results["predictions"] - ) - 
wer_percent = round(100 * wer, 2) - rtfx = round( - sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2 - ) - - print("WER:", wer_percent, "%") - print("RTFx:", rtfx) + # Only compute WER if we have samples + if len(results["predictions"]) > 0 and len(results["references"]) > 0: + wer_metric = evaluate.load("wer") + wer_result = wer_metric.compute( + references=results["references"], predictions=results["predictions"] + ) + wer_value = wer_result if isinstance(wer_result, (int, float)) else wer_result.get('wer', 0.0) + wer_percent = round(100 * wer_value, 2) + total_transcription_time = sum(results["transcription_time_s"]) + rtfx = round( + sum(results["audio_length_s"]) / total_transcription_time, 2 + ) if total_transcription_time > 0 else 0 + + print("WER:", wer_percent, "%") + print("RTFx:", rtfx) + else: + print("No samples were successfully processed - cannot compute WER") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--model_name", type=str, required=True, help="Model identifier. Should be loadable with Gemini.", + "--model_id", + type=str, + required=True, + help="Model identifier, must start with 'gemini/'", ) parser.add_argument( "--dataset", @@ -215,7 +281,7 @@ def process_sample(sample): help="Split of the dataset. Default is 'test'.", ) parser.add_argument( - "--max_samples", + "--max_eval_samples", type=int, default=None, help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.", @@ -229,8 +295,8 @@ def process_sample(sample): action="store_false", help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.", ) - args = parser.parse_args() parser.set_defaults(streaming=True) + args = parser.parse_args() if args.language is None: try: @@ -240,7 +306,9 @@ def process_sample(sample): print(f"Detected language: {args.language}") - gemini_api_key = getpass.getpass("Enter your Gemini API key: ") + gemini_api_key = os.getenv("GOOGLE_API_KEY") + if not gemini_api_key: + gemini_api_key = getpass.getpass("Enter your Gemini API key: ") genai.configure(api_key=gemini_api_key) transcribe_dataset( @@ -248,8 +316,8 @@ def process_sample(sample): config_name=args.config_name, language=args.language, split=args.split, - model_name=args.model_name, - max_samples=args.max_samples, + model_id=args.model_id, + max_eval_samples=args.max_eval_samples, max_workers=args.max_workers, streaming=args.streaming, ) diff --git a/gemini/run_gemini.sh b/gemini/run_gemini.sh index 8e085e7..71e6ed5 100644 --- a/gemini/run_gemini.sh +++ b/gemini/run_gemini.sh @@ -1,73 +1,89 @@ #!/bin/bash -export PYTHONPATH="..":$PYTHONPATH +# Set PYTHONPATH to include parent directory for proper imports +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +export PYTHONPATH="${SCRIPT_DIR}/..:$PYTHONPATH" + +# Load environment variables from .env if present (not required but convenient) +if [ -f "${SCRIPT_DIR}/.env" ]; then + set -a + # shellcheck disable=SC1090 + . "${SCRIPT_DIR}/.env" + set +a +fi + +# Use full Python path for reliability (can be overridden with PYTHON_CMD environment variable) +PYTHON_CMD=${PYTHON_CMD:-"C:\Users\Alexander.Immler\AppData\Local\Programs\Python\Python313\python.exe"} + +# Check if Google API key is set +if [ -z "$GOOGLE_API_KEY" ]; then + echo "Error: GOOGLE_API_KEY environment variable not set" + echo "Please set it with: export GOOGLE_API_KEY='your_api_key_here'" + exit 1 +fi + +# Verify Python installation +if ! 
command -v "$PYTHON_CMD" &> /dev/null; then + echo "Warning: Python not found at $PYTHON_CMD" + echo "Falling back to system python3" + PYTHON_CMD="python3" +fi MODEL_IDs=( "gemini/gemini-2.5-pro" "gemini/gemini-2.5-flash" ) +# Test with small samples first +TEST_SAMPLES=2 + for MODEL_ID in "${MODEL_IDs[@]}" do echo "--- Running Benchmarks for $MODEL_ID ---" - - # --- English Benchmarks --- - echo "--- Running English Benchmarks ---" - python gemini/run_eval.py \ + echo "Using Python: $PYTHON_CMD" + + # Test one sample first to verify setup + echo "--- Testing setup with AMI dataset ---" + if ! "$PYTHON_CMD" run_eval.py \ --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ --dataset="ami" \ --split="test" \ - --model_name ${MODEL_ID} \ - --max_samples 2 - - python gemini/run_eval.py \ - --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ - --dataset="earnings22" \ - --split="test" \ - --model_name ${MODEL_ID} \ - --max_samples 2 + --model_id "${MODEL_ID}" \ + --max_eval_samples 1; then + echo "Error: Failed to run test evaluation for $MODEL_ID" + echo "Skipping this model..." + continue + fi + + echo "--- Setup verified, continuing with full English benchmarks ---" - python gemini/run_eval.py \ - --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ - --dataset="gigaspeech" \ - --split="test" \ - --model_name ${MODEL_ID} \ - --max_samples 2 - - python gemini/run_eval.py \ - --dataset_path "hf-audio/esb-datasets-test-only-sorted" \ - --dataset "librispeech" \ - --split "test.clean" \ - --model_name ${MODEL_ID} \ - --max_samples 2 - - python gemini/run_eval.py \ - --dataset_path "hf-audio/esb-datasets-test-only-sorted" \ - --dataset "librispeech" \ - --split "test.other" \ - --model_name ${MODEL_ID} \ - --max_samples 2 - - python gemini/run_eval.py \ - --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ - --dataset="spgispeech" \ - --split="test" \ - --model_name ${MODEL_ID} \ - --max_samples 2 - - python gemini/run_eval.py \ - --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ - --dataset="tedlium" \ - --split="test" \ - --model_name ${MODEL_ID} \ - --max_samples 2 - - python gemini/run_eval.py \ - --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ - --dataset="voxpopuli" \ - --split="test" \ - --model_name ${MODEL_ID} \ - --max_samples 2 + # --- English Benchmarks --- + echo "--- Running English Benchmarks ---" + + for dataset in "earnings22" "gigaspeech" "librispeech" "spgispeech" "tedlium" "voxpopuli"; do + echo "Processing dataset: $dataset" + + if [ "$dataset" = "librispeech" ]; then + # LibriSpeech has multiple splits + for split in "test.clean" "test.other"; do + echo "Running $dataset/$split..." + "$PYTHON_CMD" run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="$dataset" \ + --split="$split" \ + --model_id "${MODEL_ID}" \ + --max_eval_samples $TEST_SAMPLES || echo "Warning: Failed to process $dataset/$split" + done + else + echo "Running $dataset..." 
+ "$PYTHON_CMD" run_eval.py \ + --dataset_path="hf-audio/esb-datasets-test-only-sorted" \ + --dataset="$dataset" \ + --split="test" \ + --model_id "${MODEL_ID}" \ + --max_eval_samples $TEST_SAMPLES || echo "Warning: Failed to process $dataset" + fi + done # --- Multilingual Benchmarks --- echo "--- Running Multilingual Benchmarks ---" @@ -77,26 +93,31 @@ do EVAL_DATASETS["mls"]="es fr it pt" for dataset in ${!EVAL_DATASETS[@]}; do + echo "Processing multilingual dataset: $dataset" for language in ${EVAL_DATASETS[$dataset]}; do config_name="${dataset}_${language}" echo "Running evaluation for $config_name" - python gemini/run_eval_ml.py \ - --model_name="$MODEL_ID" \ + "$PYTHON_CMD" run_eval_ml.py \ + --model_id="$MODEL_ID" \ --dataset="nithinraok/asr-leaderboard-datasets" \ --config_name="$config_name" \ --language="$language" \ --split="test" \ - --max_samples 2 + --max_eval_samples $TEST_SAMPLES || echo "Warning: Failed to process $config_name" done done # --- Scoring --- echo "--- Scoring results for $MODEL_ID ---" RUNDIR=$(pwd) - cd normalizer && \ - python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \ - cd $RUNDIR - + if cd ../normalizer; then + "$PYTHON_CMD" -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" || echo "Warning: Scoring failed for $MODEL_ID" + cd "$RUNDIR" || echo "Warning: Could not return to original directory" + else + echo "Warning: Could not access normalizer directory for scoring" + fi + + echo "--- Completed benchmarks for $MODEL_ID ---" done echo "--- All benchmarks complete ---" From 6c9d8abb1a69fabcae543199589f6fb09e49b1db Mon Sep 17 00:00:00 2001 From: Alexander Immler Date: Fri, 22 Aug 2025 12:13:24 +0200 Subject: [PATCH 3/3] gemini: docs + scripts cleanup; generic Python resolution, .env auto-load; updated README with full workflow --- gemini/README.md | 124 ++++++++++++++++++++++++++++++++++++++ gemini/run_gemini.ps1 | 135 ++++++++++++++++++++++++++++++++++++++++++ gemini/run_gemini.sh | 13 ++-- gemini/run_temp.ps1 | 37 ++++++++++++ 4 files changed, 305 insertions(+), 4 deletions(-) create mode 100644 gemini/README.md create mode 100644 gemini/run_gemini.ps1 create mode 100644 gemini/run_temp.ps1 diff --git a/gemini/README.md b/gemini/README.md new file mode 100644 index 0000000..5e78951 --- /dev/null +++ b/gemini/README.md @@ -0,0 +1,124 @@ +# Gemini ASR Evaluation + +This folder contains a self-contained flow to evaluate Google Gemini models on the Open ASR Leaderboard datasets. It aligns with the repository’s template: consistent dataset loading, normalization, manifest writing, WER/RTFx computation, and scripts to run and score. + +## Quick Start + +1) Install dependencies (from repo root or ensure they’re available): + +```bash +# From repo root +pip install -r ../requirements/requirements.txt + +# From gemini/ (adds Gemini client) +pip install -r requirements_gemini.txt +``` + +2) Provide your Gemini API key: + +```bash +# Option A (recommended): Put it in gemini/.env +echo "GOOGLE_API_KEY=your_api_key_here" > .env + +# Option B: Export it in your shell for this session +export GOOGLE_API_KEY=your_api_key_here # bash/zsh +# or +$env:GOOGLE_API_KEY = "your_api_key_here" # PowerShell +``` + +3) Ensure Python can import the repo’s modules. Scripts set `PYTHONPATH` automatically. For manual runs from gemini/: + +```bash +export PYTHONPATH="$(pwd)/.." # bash/zsh +# or +$env:PYTHONPATH = ".." 
# PowerShell +``` + +## How It Works + +- Data loading: For English, audio is accessed without automatic decoding to avoid `torchcodec`; audio bytes/paths are read via `soundfile` and cached under `gemini/audio_cache//`. Multilingual follows the shared pattern. +- Normalization: References and predictions are normalized (English: `EnglishTextNormalizer`; multilingual: `BasicMultilingualTextNormalizer`). +- Transcription: Each audio file is uploaded to the Gemini API, transcribed with retries + exponential backoff, then cleaned up. +- Outputs: Each run writes a JSONL manifest under `gemini/results/` and prints WER and RTFx. +- Scoring: `normalizer/eval_utils.score_results` aggregates across results and prints per-dataset and composite metrics. + +## Run Individual Evaluations + +English (run_eval.py): + +```bash +python run_eval.py \ + --model_id "gemini/gemini-2.5-pro" \ + --dataset_path "hf-audio/esb-datasets-test-only-sorted" \ + --dataset "ami" \ + --split "test" \ + --max_eval_samples 2 +``` + +Multilingual (run_eval_ml.py): + +```bash +python run_eval_ml.py \ + --model_id "gemini/gemini-2.5-pro" \ + --dataset "nithinraok/asr-leaderboard-datasets" \ + --config_name "fleurs_en" \ + --language "en" \ + --split "test" \ + --max_eval_samples 2 +``` + +Notes: +- `--model_id` must start with `gemini/` (e.g., `gemini/gemini-2.5-pro`, `gemini/gemini-2.5-flash`). +- English script loads audio offline via bytes/path—no `torchcodec` required. + +## Run Full Benchmark Suite + +Both scripts resolve Python from PATH automatically; you can override with `PYTHON_CMD`. + +- Bash (Linux/macOS): +```bash +chmod +x run_gemini.sh +./run_gemini.sh +``` + +- PowerShell (Windows): +```powershell +./run_gemini.ps1 +``` + +Behavior: +- Auto-loads `gemini/.env` if present (so you don’t need to export `GOOGLE_API_KEY` manually). +- Sets `PYTHONPATH` to the repo root automatically. +- Runs a short smoke test first, then loops through core English datasets (and multilingual configs) with a small sample size for validation. Adjust sample sizes and datasets in the scripts as needed. + +## Scoring Results + +Score all manifests under `gemini/results/` for a given model id: + +```bash +python -c "import normalizer.eval_utils as e; e.score_results('gemini/results', 'gemini/gemini-2.5-pro')" +``` + +This prints per-dataset WER and RTFx and a composite WER/RTFx by model. + +## Environment Variables + +- `GOOGLE_API_KEY` (required): Gemini API key. Set via `.env` or your shell. +- `PYTHONPATH`: Path to the repo root. Scripts set this automatically; for manual runs set it to `..` from inside `gemini/`. +- `PYTHON_CMD` (optional): Override which Python to use in the scripts (e.g., `PYTHON_CMD=/path/to/python`). +- `HF_TOKEN` (optional): Hugging Face token (only needed for private datasets). + +## Troubleshooting + +- Missing packages: Install both the repo requirements and `requirements_gemini.txt`. +- API key errors: Ensure `GOOGLE_API_KEY` is set. Scripts read `.env` automatically. +- Exec permissions (Linux/macOS): `chmod +x run_gemini.sh`. +- Torchcodec errors: English script reads audio from bytes/paths with `soundfile` and does not require `torchcodec`. + +## Files + +- `run_eval.py`: English evaluation script (Gemini transcription + WER/RTFx + manifest writing). +- `run_eval_ml.py`: Multilingual evaluation script. +- `run_gemini.sh`/`run_gemini.ps1`: Batch runners (auto-load `.env`, resolve Python, set `PYTHONPATH`). +- `requirements_gemini.txt`: Gemini client dependency. 
+- `audio_cache/`, `results/`: Local outputs (cached audio and manifests). diff --git a/gemini/run_gemini.ps1 b/gemini/run_gemini.ps1 new file mode 100644 index 0000000..88a9437 --- /dev/null +++ b/gemini/run_gemini.ps1 @@ -0,0 +1,135 @@ +# PowerShell version of run_gemini.sh (clean) + +# Set environment variables +$env:PYTHONPATH = ".." + +# Load .env if present to populate GOOGLE_API_KEY and others +try { + $envFile = Join-Path $PSScriptRoot ".env" + if (Test-Path $envFile) { + Get-Content -Path $envFile | ForEach-Object { + $line = $_.Trim() + if ($line -and -not $line.StartsWith('#')) { + if ($line -match '^(?[^#=]+?)\s*=\s*(?.*)$') { + $k = $Matches['k'].Trim() + $v = $Matches['v'].Trim() + if ($k) { Set-Item -Path Env:$k -Value $v | Out-Null } + } + } + } + } +} catch {} + +# Check if Google API key is set +if (-not $env:GOOGLE_API_KEY) { + Write-Host "Error: GOOGLE_API_KEY environment variable not set" -ForegroundColor Red + Write-Host "Please set it with: `$env:GOOGLE_API_KEY='your_api_key_here'" -ForegroundColor Yellow + exit 1 +} + +# Resolve Python from PATH (override by setting $env:PYTHON_CMD) +$PythonExe = $env:PYTHON_CMD +if (-not $PythonExe) { $PythonExe = (Get-Command python -ErrorAction SilentlyContinue)?.Source } +if (-not $PythonExe) { $PythonExe = (Get-Command py -ErrorAction SilentlyContinue)?.Source } +if (-not $PythonExe) { $PythonExe = "python" } +Write-Host "Using Python: $PythonExe" -ForegroundColor Green + +$MODEL_IDs = @( + "gemini/gemini-2.5-pro", + "gemini/gemini-2.5-flash" +) + +$BATCH_SIZE = 8 +$MAX_SAMPLES = 2 +Write-Host "Test samples per dataset: $MAX_SAMPLES" -ForegroundColor Green + +foreach ($MODEL_ID in $MODEL_IDs) { + Write-Host "--- Running Benchmarks for $MODEL_ID ---" -ForegroundColor Cyan + + # Quick setup test + Write-Host "--- Testing setup with AMI dataset ---" -ForegroundColor Yellow + try { + & $PythonExe run_eval.py ` + --model_id="$MODEL_ID" ` + --dataset_path="hf-audio/esb-datasets-test-only-sorted" ` + --dataset="ami" ` + --split="test" ` + --max_eval_samples=1 + Write-Host "✓ Setup verified, continuing with benchmarks" -ForegroundColor Green + } catch { + Write-Host "✗ Setup test failed for $MODEL_ID, skipping..." -ForegroundColor Red + continue + } + + # English benchmarks + Write-Host "--- Running English Benchmarks ---" -ForegroundColor Cyan + $datasets = @("ami", "earnings22", "gigaspeech", "librispeech", "spgispeech", "tedlium", "voxpopuli") + foreach ($dataset in $datasets) { + Write-Host "Processing dataset: $dataset" -ForegroundColor Yellow + if ($dataset -eq "librispeech") { + foreach ($split in @("test.clean", "test.other")) { + Write-Host "Running $dataset/$split..." -ForegroundColor White + try { + & $PythonExe run_eval.py ` + --model_id="$MODEL_ID" ` + --dataset_path="hf-audio/esb-datasets-test-only-sorted" ` + --dataset="$dataset" ` + --split="$split" ` + --batch_size=$BATCH_SIZE ` + --max_eval_samples=$MAX_SAMPLES + Write-Host "✓ Completed $dataset/$split" -ForegroundColor Green + } catch { + Write-Host "✗ Warning: Failed to process $dataset/$split" -ForegroundColor Red + } + } + } else { + Write-Host "Running $dataset..." 
-ForegroundColor White + try { + & $PythonExe run_eval.py ` + --model_id="$MODEL_ID" ` + --dataset_path="hf-audio/esb-datasets-test-only-sorted" ` + --dataset="$dataset" ` + --split="test" ` + --batch_size=$BATCH_SIZE ` + --max_eval_samples=$MAX_SAMPLES + Write-Host "✓ Completed $dataset" -ForegroundColor Green + } catch { + Write-Host "✗ Warning: Failed to process $dataset" -ForegroundColor Red + } + } + } + + # Multilingual + Write-Host "--- Running Multilingual Benchmarks ---" -ForegroundColor Cyan + $EVAL_DATASETS = @{ + "fleurs" = @("en", "de", "fr", "it", "es", "pt") + "mcv" = @("en", "de", "es", "fr", "it") + "mls" = @("es", "fr", "it", "pt") + } + foreach ($ds in $EVAL_DATASETS.Keys) { + foreach ($lang in $EVAL_DATASETS[$ds]) { + $config = "${ds}_${lang}" + Write-Host "Running evaluation for $config" -ForegroundColor White + try { + & $PythonExe run_eval_ml.py ` + --model_id="$MODEL_ID" ` + --dataset="nithinraok/asr-leaderboard-datasets" ` + --config_name="$config" ` + --language="$lang" ` + --split="test" ` + --max_eval_samples=$MAX_SAMPLES + } catch { + Write-Host "✗ Warning: Failed to process $config" -ForegroundColor Red + } + } + } + + # Scoring + Write-Host "--- Scoring results for $MODEL_ID ---" -ForegroundColor Cyan + $RUNDIR = Get-Location + Set-Location "..\normalizer" + & $PythonExe -c "import eval_utils; eval_utils.score_results('$RUNDIR\results', '$MODEL_ID')" + Set-Location $RUNDIR +} + +Write-Host "--- All benchmarks complete ---" diff --git a/gemini/run_gemini.sh b/gemini/run_gemini.sh index 71e6ed5..90b4c80 100644 --- a/gemini/run_gemini.sh +++ b/gemini/run_gemini.sh @@ -12,8 +12,8 @@ if [ -f "${SCRIPT_DIR}/.env" ]; then set +a fi -# Use full Python path for reliability (can be overridden with PYTHON_CMD environment variable) -PYTHON_CMD=${PYTHON_CMD:-"C:\Users\Alexander.Immler\AppData\Local\Programs\Python\Python313\python.exe"} +# Use Python from PATH by default (override with PYTHON_CMD if desired) +PYTHON_CMD=${PYTHON_CMD:-python3} # Check if Google API key is set if [ -z "$GOOGLE_API_KEY" ]; then @@ -25,8 +25,13 @@ fi # Verify Python installation if ! 
command -v "$PYTHON_CMD" &> /dev/null; then echo "Warning: Python not found at $PYTHON_CMD" - echo "Falling back to system python3" - PYTHON_CMD="python3" + if command -v python3 &> /dev/null; then + echo "Falling back to python3" + PYTHON_CMD="python3" + else + echo "Falling back to python" + PYTHON_CMD="python" + fi fi MODEL_IDs=( diff --git a/gemini/run_temp.ps1 b/gemini/run_temp.ps1 new file mode 100644 index 0000000..5656f7e --- /dev/null +++ b/gemini/run_temp.ps1 @@ -0,0 +1,37 @@ +# Set environment variables +$env:GOOGLE_API_KEY = "" +$env:PYTHONPATH = (Resolve-Path (Join-Path $PSScriptRoot '..')).Path + +# Change to gemini directory +Set-Location $PSScriptRoot + +# Resolve Python from PATH or use env override +$PythonExe = $env:PYTHON_CMD +if (-not $PythonExe) { $PythonExe = (Get-Command python -ErrorAction SilentlyContinue)?.Source } +if (-not $PythonExe) { $PythonExe = (Get-Command py -ErrorAction SilentlyContinue)?.Source } +if (-not $PythonExe) { $PythonExe = "python" } + +Write-Host "Using Python: $PythonExe" -ForegroundColor Green +Write-Host "Running on ENTIRE datasets (no sample limit)" -ForegroundColor Green + +# Run full evaluation on AMI dataset +Write-Host "--- Running full AMI dataset evaluation ---" -ForegroundColor Cyan +try { + & $PythonExe run_eval.py --model_name="gemini/gemini-2.5-flash" --dataset_path="hf-audio/esb-datasets-test-only-sorted" --dataset="ami" --split="test" + Write-Host "✓ AMI evaluation completed successfully" -ForegroundColor Green +} catch { + Write-Host "✗ AMI evaluation failed: $($_.Exception.Message)" -ForegroundColor Red + exit 1 +} + +# Run full multilingual evaluation +Write-Host "--- Running full multilingual evaluation (FLEURS English) ---" -ForegroundColor Cyan +try { + & $PythonExe run_eval_ml.py --model_name="gemini/gemini-2.5-flash" --dataset="nithinraok/asr-leaderboard-datasets" --config_name="fleurs_en" --language="en" --split="test" + Write-Host "✓ Multilingual evaluation completed successfully" -ForegroundColor Green +} catch { + Write-Host "✗ Multilingual evaluation failed: $($_.Exception.Message)" -ForegroundColor Red + exit 1 +} + +Write-Host "--- All full dataset evaluations completed successfully ---" -ForegroundColor Green