124 changes: 124 additions & 0 deletions gemini/README.md
@@ -0,0 +1,124 @@
# Gemini ASR Evaluation

This folder contains a self-contained flow to evaluate Google Gemini models on the Open ASR Leaderboard datasets. It aligns with the repository's template: consistent dataset loading, normalization, manifest writing, WER/RTFx computation, and scripts to run and score.

## Quick Start

1) Install dependencies (from repo root, or ensure they're available):

```bash
# From repo root
pip install -r ../requirements/requirements.txt

# From gemini/ (adds Gemini client)
pip install -r requirements_gemini.txt
```

2) Provide your Gemini API key:

```bash
# Option A (recommended): Put it in gemini/.env
echo "GOOGLE_API_KEY=your_api_key_here" > .env

# Option B: Export it in your shell for this session
export GOOGLE_API_KEY=your_api_key_here # bash/zsh
# or
$env:GOOGLE_API_KEY = "your_api_key_here" # PowerShell
```

3) Ensure Python can import the repo's modules. The scripts set `PYTHONPATH` automatically. For manual runs from `gemini/`:

```bash
export PYTHONPATH="$(pwd)/.." # bash/zsh
# or
$env:PYTHONPATH = ".." # PowerShell
```

## How It Works

- Data loading: For English, audio is accessed without automatic decoding to avoid `torchcodec`; audio bytes/paths are read via `soundfile` and cached under `gemini/audio_cache/<dataset>/<split>`. Multilingual follows the shared pattern.
- Normalization: References and predictions are normalized (English: `EnglishTextNormalizer`; multilingual: `BasicMultilingualTextNormalizer`).
- Transcription: Each audio file is uploaded to the Gemini API, transcribed with retries + exponential backoff, then cleaned up.
- Outputs: Each run writes a JSONL manifest under `gemini/results/` and prints WER and RTFx (total audio duration divided by total transcription time, so higher is faster; e.g., 600 s of audio transcribed in 120 s gives RTFx 5.0).
- Scoring: `normalizer/eval_utils.score_results` aggregates across results and prints per-dataset and composite metrics; a sketch of the per-manifest computation follows this list.
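
For orientation, here is a minimal sketch of how WER and RTFx could be recomputed from a single manifest. The field names (`text`, `pred_text`, `duration`, `time`) and the filename are assumptions for illustration; the authoritative schema is whatever `data_utils.write_manifest` emits.

```python
import json

import evaluate  # same metric package the eval scripts use

refs, preds, audio_s, compute_s = [], [], [], []
# Hypothetical manifest path; real files live under gemini/results/
with open("results/example_manifest.jsonl") as f:
    for line in f:
        row = json.loads(line)
        refs.append(row["text"])          # assumed field: normalized reference
        preds.append(row["pred_text"])    # assumed field: normalized prediction
        audio_s.append(row["duration"])   # assumed field: audio length (seconds)
        compute_s.append(row["time"])     # assumed field: transcription time (seconds)

wer = 100 * evaluate.load("wer").compute(references=refs, predictions=preds)
rtfx = sum(audio_s) / sum(compute_s)  # >1 means faster than real time
print(f"WER: {wer:.2f}%  RTFx: {rtfx:.2f}")
```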

## Run Individual Evaluations

English (run_eval.py):

```bash
python run_eval.py \
--model_id "gemini/gemini-2.5-pro" \
--dataset_path "hf-audio/esb-datasets-test-only-sorted" \
--dataset "ami" \
--split "test" \
--max_eval_samples 2
```

Multilingual (run_eval_ml.py):

```bash
python run_eval_ml.py \
--model_id "gemini/gemini-2.5-pro" \
--dataset "nithinraok/asr-leaderboard-datasets" \
--config_name "fleurs_en" \
--language "en" \
--split "test" \
--max_eval_samples 2
```

Notes:
- `--model_id` must start with `gemini/` (e.g., `gemini/gemini-2.5-pro`, `gemini/gemini-2.5-flash`).
- The English script loads audio offline via bytes/paths, so `torchcodec` is not required.

## Run Full Benchmark Suite

Both scripts resolve Python from PATH automatically; you can override with `PYTHON_CMD`.

- Bash (Linux/macOS):
```bash
chmod +x run_gemini.sh
./run_gemini.sh
```

- PowerShell (Windows):
```powershell
./run_gemini.ps1
```

Behavior:
- Auto-loads `gemini/.env` if present, so you don't need to export `GOOGLE_API_KEY` manually (one minimal pattern for doing this yourself is sketched after this list).
- Sets `PYTHONPATH` to the repo root automatically.
- Runs a short smoke test first, then loops through core English datasets (and multilingual configs) with a small sample size for validation. Adjust sample sizes and datasets in the scripts as needed.
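
If you want the same `.env` behavior in an interactive shell, one minimal pattern (an illustration, not necessarily the exact code in the runner scripts) is:

```bash
# Export every KEY=value assignment in .env into the current environment
set -a
[ -f .env ] && . ./.env
set +a
```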

## Scoring Results

Score all manifests under `gemini/results/` for a given model id:

```bash
python -c "import normalizer.eval_utils as e; e.score_results('gemini/results', 'gemini/gemini-2.5-pro')"
```

This prints per-dataset WER and RTFx, plus a composite WER/RTFx per model.

## Environment Variables

- `GOOGLE_API_KEY` (required): Gemini API key. Set via `.env` or your shell.
- `PYTHONPATH`: Path to the repo root. Scripts set this automatically; for manual runs set it to `..` from inside `gemini/`.
- `PYTHON_CMD` (optional): Override which Python to use in the scripts (e.g., `PYTHON_CMD=/path/to/python`; see the example after this list).
- `HF_TOKEN` (optional): Hugging Face token (only needed for private datasets).
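
For example, a one-off run with an explicit interpreter (the path is a placeholder):

```bash
PYTHON_CMD="$HOME/.venvs/asr/bin/python" ./run_gemini.sh
```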

## Troubleshooting

- Missing packages: Install both the repo requirements and `requirements_gemini.txt`.
- API key errors: Ensure `GOOGLE_API_KEY` is set; the scripts read `.env` automatically. A one-line check is shown below.
- Exec permissions (Linux/macOS): `chmod +x run_gemini.sh`.
- Torchcodec errors: English script reads audio from bytes/paths with `soundfile` and does not require `torchcodec`.
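
A quick way to confirm the key is visible to Python before a long run:

```bash
python -c "import os; print(bool(os.getenv('GOOGLE_API_KEY')))"
```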

## Files

- `run_eval.py`: English evaluation script (Gemini transcription + WER/RTFx + manifest writing).
- `run_eval_ml.py`: Multilingual evaluation script.
- `run_gemini.sh`/`run_gemini.ps1`: Batch runners (auto-load `.env`, resolve Python, set `PYTHONPATH`).
- `requirements_gemini.txt`: Gemini client dependency.
- `audio_cache/`, `results/`: Local outputs (cached audio and manifests).
1 change: 1 addition & 0 deletions gemini/requirements_gemini.txt
@@ -0,0 +1 @@
google-generativeai
228 changes: 228 additions & 0 deletions gemini/run_eval.py
@@ -0,0 +1,228 @@
import argparse
import io
import os
import time
import getpass

import evaluate
import datasets
import numpy as np
import soundfile as sf
from tqdm import tqdm

from normalizer import data_utils

try:
    import google.generativeai as genai
except ImportError:
    print("Error: google-generativeai not installed. Run: pip install google-generativeai")
    exit(1)


def transcribe_with_retry(
    model_id: str,
    audio_file_path: str,
    max_retries: int = 10,
) -> str:
    # Validate the model id up front: a bad prefix is not retryable.
    if not model_id.startswith("gemini/"):
        raise ValueError("Invalid model prefix, must start with 'gemini/'")
    model = genai.GenerativeModel(model_id.split("/", 1)[1])

    retries = 0
    while retries <= max_retries:
        try:
            # Upload the audio, request a transcript, then delete the uploaded file.
            gemini_file = genai.upload_file(path=audio_file_path)
            response = model.generate_content([
                "Generate a transcript of the speech.",
                gemini_file,
            ])
            genai.delete_file(gemini_file.name)

            return response.text.strip() if getattr(response, "text", None) else ""

        except Exception as e:
            retries += 1
            if retries > max_retries:
                raise

            delay = min(2 ** retries, 30)  # Exponential backoff with max 30s
            print(
                f"API Error: {e}. Retrying in {delay}s... (Attempt {retries}/{max_retries})"
            )
            time.sleep(delay)

    # This should never be reached, but keeps the return type total.
    return ""


def main(args):
    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
    CACHE_DIR = os.path.join(DATA_CACHE_DIR, args.dataset, args.split)
    os.makedirs(CACHE_DIR, exist_ok=True)

    # Load dataset without triggering audio decoding (avoid torchcodec)
    ds = datasets.load_dataset(
        args.dataset_path,
        args.dataset,
        split=args.split,
        streaming=False,
        token=True,
    )
    # Keep audio as filepaths to avoid decoding here
    try:
        from datasets import Audio
        ds = ds.cast_column("audio", Audio(decode=False))
    except Exception:
        pass

    # Subsample
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if hasattr(ds, "select") and hasattr(ds, "__len__"):
            ds = ds.select(range(min(args.max_eval_samples, len(ds))))

    references = []
    audio_paths = []
    durations = []

    for sample in tqdm(ds, desc="Preparing samples"):
        sid = str(sample.get("id", "sample")).replace("/", "_").removesuffix(".wav")
        audio_info = sample.get("audio")
        if not isinstance(audio_info, dict):
            print("Skipping sample without audio info")
            continue
        try:
            if audio_info.get("bytes") is not None:
                with io.BytesIO(audio_info["bytes"]) as bio:
                    audio_array, sr = sf.read(bio, dtype="float32")
            elif audio_info.get("path"):
                audio_array, sr = sf.read(audio_info["path"], dtype="float32")
            elif audio_info.get("array") is not None:
                audio_array = np.asarray(audio_info["array"], dtype=np.float32)
                sr = audio_info.get("sampling_rate", 16000)
            else:
                print("Skipping sample: unsupported audio format")
                continue
        except Exception as e:
            print(f"Failed to read audio: {e}")
            continue

        # Cache the decoded audio as a WAV file for upload to the Gemini API
        out_path = os.path.join(CACHE_DIR, f"{sid}.wav")
        if not os.path.exists(out_path):
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            sf.write(out_path, audio_array, sr)

        audio_paths.append(out_path)
        durations.append(len(audio_array) / sr)

        # Normalize reference text
        try:
            ref_text = data_utils.get_text(sample)
        except Exception:
            ref_text = sample.get("text", " ")
        references.append(data_utils.normalizer(ref_text) or " ")

        if args.max_eval_samples is not None and len(audio_paths) >= args.max_eval_samples:
            break

    # Transcribe
    predictions = []
    transcription_times = []
    print(f"Transcribing with model: {args.model_id}")
    for audio_path in tqdm(audio_paths, desc="Transcribing"):
        start = time.time()
        try:
            pred_text = transcribe_with_retry(args.model_id, audio_path)
        except Exception as e:
            print(f"Failed to transcribe {audio_path}: {e}")
            pred_text = " "
        elapsed = time.time() - start
        transcription_times.append(elapsed)
        predictions.append(data_utils.normalizer(pred_text) or " ")
        time.sleep(0.1)  # brief pause between requests

    if len(predictions) == 0:
        print("No samples were successfully processed.")
        return

    manifest_path = data_utils.write_manifest(
        references,
        predictions,
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=durations,
        transcription_time=transcription_times,
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)
    # RTFx = total audio seconds / total transcription seconds
    rtfx = round(sum(durations) / max(1e-9, sum(transcription_times)), 2)
    print("WER:", wer, "%")
    print("RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Gemini ASR Evaluation Script")
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier, must start with 'gemini/'",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of samples per streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Download the entire dataset instead of streaming.",
    )
    parser.set_defaults(streaming=True)

    args = parser.parse_args()

    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        try:
            api_key = getpass.getpass("Enter your Gemini API key: ")
        except Exception:
            api_key = None
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY not set and no key provided interactively.")
    genai.configure(api_key=api_key)

    main(args)