
Commit 59dd8be

Merge pull request #58 from huggingface/api_openai
Adding API-based models
2 parents afba94c + 0f1fc14

File tree

api/run_api.sh
api/run_eval.py

2 files changed: +368, -0 lines changed

api/run_api.sh

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

export OPENAI_API_KEY="your_api_key"
export ASSEMBLYAI_API_KEY="your_api_key"
export ELEVENLABS_API_KEY="your_api_key"
export REVAI_API_KEY="your_api_key"

MODEL_IDs=(
    "openai/gpt-4o-transcribe"
    "openai/gpt-4o-mini-transcribe"
    "openai/whisper-1"
    "assembly/best"
    "elevenlabs/scribe_v1"
    "revai/machine" # please pass --use_url (a bare flag; it takes no value)
    "revai/fusion"  # please pass --use_url (a bare flag; it takes no value)
)

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.clean" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.other" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --model_name ${MODEL_ID}

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
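The scoring step at the end of each loop iteration can also be run by hand from the normalizer/ directory. A minimal sketch, assuming manifests were written under api/results as in the script above (the model ID shown is one of the entries from MODEL_IDs):

import eval_utils

# Score all manifests produced for one model; arguments mirror the shell call above.
eval_utils.score_results("../api/results", "openai/whisper-1")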

api/run_eval.py

Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
import argparse
import datasets
import evaluate
import soundfile as sf
import tempfile
import time
import os
import requests
from tqdm import tqdm
from dotenv import load_dotenv
from io import BytesIO
import assemblyai as aai
import openai
from elevenlabs.client import ElevenLabs
from rev_ai import apiclient
from rev_ai.models import CustomVocabulary, CustomerUrlData
from normalizer import data_utils
import concurrent.futures

load_dotenv()
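# Note (added): the imports above pull in these third-party packages; the PyPI
# names below are assumptions, not pinned requirements from this repo:
#   datasets, evaluate, soundfile, requests, tqdm, python-dotenv,
#   assemblyai, openai, elevenlabs, rev_ai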
def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
    API_URL = "https://datasets-server.huggingface.co/rows"

    size_url = f"https://datasets-server.huggingface.co/size?dataset={dataset_path}&config={dataset}&split={split}"
    size_response = requests.get(size_url).json()
    total_rows = size_response['size']['config']['num_rows']
    audio_urls = []
    for offset in tqdm(range(0, total_rows, batch_size), desc="Fetching audio URLs"):
        params = {
            "dataset": dataset_path,
            "config": dataset,
            "split": split,
            "offset": offset,
            "length": min(batch_size, total_rows - offset)
        }

        retries = 0
        while retries <= max_retries:
            try:
                response = requests.get(API_URL, params=params)
                response.raise_for_status()
                data = response.json()
                audio_urls.extend(data['rows'])
                break
            except (requests.exceptions.RequestException, ValueError) as e:
                retries += 1
                print(f"Error fetching data: {e}, retrying ({retries}/{max_retries})...")
                time.sleep(10)
                if retries >= max_retries:
                    raise Exception("Max retries exceeded while fetching data.")
        time.sleep(1)
    return audio_urls
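# Illustrative shape of one entry returned by fetch_audio_urls, inferred from the
# field accesses later in this script (the values shown are hypothetical):
#   {
#       "row": {
#           "text": "reference transcript",
#           "audio": [{"src": "https://datasets-server.huggingface.co/assets/..."}],
#           "audio_length_s": 4.2,
#       }
#   }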
def transcribe_with_retry(model_name, audio_file_path, sample, max_retries=10, use_url=False):
    retries = 0
    while retries <= max_retries:
        try:
            if model_name.startswith("assembly/"):
                aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY")
                transcriber = aai.Transcriber()
                config = aai.TranscriptionConfig(
                    speech_model=model_name.split("/")[1],
                    language_code="en",
                )
                if use_url:
                    audio_url = sample['row']['audio'][0]['src']
                    audio_duration = sample['row']['audio_length_s']
                    if audio_duration < 0.160:
                        print(f"Skipping audio duration {audio_duration}s")
                        return "."
                    transcript = transcriber.transcribe(audio_url, config=config)
                else:
                    audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
                    if audio_duration < 0.160:
                        print(f"Skipping audio duration {audio_duration}s")
                        return "."
                    transcript = transcriber.transcribe(audio_file_path, config=config)

                if transcript.status == aai.TranscriptStatus.error:
                    raise Exception(f"AssemblyAI transcription error: {transcript.error}")
                return transcript.text

            elif model_name.startswith("openai/"):
                if use_url:
                    response = requests.get(sample['row']['audio'][0]['src'])
                    audio_data = BytesIO(response.content)
                    response = openai.Audio.transcribe(
                        model=model_name.split("/")[1],
                        file=audio_data,
                        response_format="text",
                        language="en",
                        temperature=0.0,
                    )
                else:
                    with open(audio_file_path, "rb") as audio_file:
                        response = openai.Audio.transcribe(
                            model=model_name.split("/")[1],
                            file=audio_file,
                            response_format="text",
                            language="en",
                            temperature=0.0,
                        )
                return response.strip()

            elif model_name.startswith("elevenlabs/"):
                client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY"))
                if use_url:
                    response = requests.get(sample['row']['audio'][0]['src'])
                    audio_data = BytesIO(response.content)
                    transcription = client.speech_to_text.convert(
                        file=audio_data,
                        model_id=model_name.split("/")[1],
                        language_code="eng",
                        tag_audio_events=True,
                    )
                else:
                    with open(audio_file_path, "rb") as audio_file:
                        transcription = client.speech_to_text.convert(
                            file=audio_file,
                            model_id=model_name.split("/")[1],
                            language_code="eng",
                            tag_audio_events=True,
                        )
                return transcription.text

            elif model_name.startswith("revai/"):
                access_token = os.getenv("REVAI_API_KEY")
                client = apiclient.RevAiAPIClient(access_token)

                if use_url:
                    # Submit job with URL for Rev.ai
                    job = client.submit_job_url(
                        transcriber=model_name.split("/")[1],
                        source_config=CustomerUrlData(sample['row']['audio'][0]['src']),
                        metadata="benchmarking_job",
                    )
                else:
                    # Submit job with local file
                    job = client.submit_job_local_file(
                        transcriber=model_name.split("/")[1],
                        filename=audio_file_path,
                        metadata="benchmarking_job",
                    )

                # Polling until job is done
                while True:
                    job_details = client.get_job_details(job.id)
                    if job_details.status.name in ["IN_PROGRESS", "TRANSCRIBING"]:
                        time.sleep(0.1)
                        continue
                    elif job_details.status.name == "FAILED":
                        raise Exception("RevAI transcription failed.")
                    elif job_details.status.name == "TRANSCRIBED":
                        break

                transcript_object = client.get_transcript_object(job.id)

                # Combine all words from all monologues
                transcript_text = []
                for monologue in transcript_object.monologues:
                    for element in monologue.elements:
                        transcript_text.append(element.value)

                return "".join(transcript_text) if transcript_text else ""

            else:
                raise ValueError("Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/' or 'revai/'")

        except Exception as e:
            retries += 1
            if retries > max_retries:
                return "."

            if not use_url:
                sf.write(audio_file_path, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV")
            delay = 1
            print(f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})")
            time.sleep(delay)
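# Usage note (added): model_name must carry one of the provider prefixes
# dispatched on above. Valid examples, taken from run_api.sh:
#   "assembly/best", "openai/whisper-1", "elevenlabs/scribe_v1", "revai/machine"
# The run_api.sh comments recommend passing --use_url for the Rev.ai models.
# On any API error the function retries with a fixed 1 s delay and finally
# returns "." once max_retries is exceeded.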
def transcribe_dataset(dataset_path, dataset, split, model_name, use_url=False, max_samples=None, max_workers=4):
    if use_url:
        audio_rows = fetch_audio_urls(dataset_path, dataset, split)
        if max_samples:
            audio_rows = audio_rows[:max_samples]
        ds = audio_rows
    else:
        ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
        ds = data_utils.prepare_data(ds)
        if max_samples:
            ds = ds.take(max_samples)

    results = {"references": [], "predictions": [], "audio_length_s": [], "transcription_time_s": []}

    print(f"Transcribing with model: {model_name}")

    def process_sample(sample):
        if use_url:
            reference = sample['row']['text'].strip() or " "
            audio_duration = sample['row']['audio_length_s']
            start = time.time()
            try:
                transcription = transcribe_with_retry(model_name, None, sample, use_url=True)
            except Exception as e:
                print(f"Failed to transcribe after retries: {e}")
                return None

        else:
            reference = sample.get("norm_text", "").strip() or " "
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
                sf.write(tmpfile.name, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV")
                tmp_path = tmpfile.name
            audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]

            start = time.time()
            try:
                transcription = transcribe_with_retry(model_name, tmp_path, sample, use_url=False)
            except Exception as e:
                print(f"Failed to transcribe after retries: {e}")
                os.unlink(tmp_path)
                return None
            finally:
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
                else:
                    print(f"File {tmp_path} does not exist")

        transcription_time = time.time() - start
        return reference, transcription, audio_duration, transcription_time

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_sample = {executor.submit(process_sample, sample): sample for sample in ds}
        for future in tqdm(concurrent.futures.as_completed(future_to_sample), total=len(future_to_sample), desc="Transcribing"):
            result = future.result()
            if result:
                reference, transcription, audio_duration, transcription_time = result
                results["predictions"].append(transcription)
                results["references"].append(reference)
                results["audio_length_s"].append(audio_duration)
                results["transcription_time_s"].append(transcription_time)

    results["predictions"] = [data_utils.normalizer(transcription) or " " for transcription in results["predictions"]]
    results["references"] = [data_utils.normalizer(reference) or " " for reference in results["references"]]

    manifest_path = data_utils.write_manifest(
        results["references"],
        results["predictions"],
        model_name.replace("/", "-"),
        dataset_path,
        dataset,
        split,
        audio_length=results["audio_length_s"],
        transcription_time=results["transcription_time_s"],
    )

    print("Results saved at path:", manifest_path)

    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(references=results["references"], predictions=results["predictions"])
    wer_percent = round(100 * wer, 2)
    rtfx = round(sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2)

    print("WER:", wer_percent, "%")
    print("RTFx:", rtfx)
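# Worked example (added): RTFx is total audio seconds divided by total
# transcription seconds, so 3600 s of audio transcribed in 240 s of wall time
# gives RTFx = 15. A WER sanity check on toy strings (illustrative):
#   evaluate.load("wer").compute(references=["the cat sat"],
#                                predictions=["the cat sat down"])
#   -> one insertion over three reference words = 1/3, printed as 33.33 %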
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unified Transcription Script with Concurrency")
    parser.add_argument("--dataset_path", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--split", default="test")
    parser.add_argument("--model_name", required=True, help="Prefix model name with 'assembly/', 'openai/', 'elevenlabs/', or 'revai/'")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--max_workers", type=int, default=300, help="Number of concurrent threads")
    parser.add_argument("--use_url", action="store_true", help="Use URL-based audio fetching instead of datasets")

    args = parser.parse_args()

    transcribe_dataset(
        dataset_path=args.dataset_path,
        dataset=args.dataset,
        split=args.split,
        model_name=args.model_name,
        use_url=args.use_url,
        max_samples=args.max_samples,
        max_workers=args.max_workers,
    )
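For programmatic use, transcribe_dataset can also be imported and called directly. A minimal sketch with illustrative arguments (max_samples=10 is an arbitrary smoke-test size, not something the script itself uses):

from run_eval import transcribe_dataset

transcribe_dataset(
    dataset_path="hf-audio/esb-datasets-test-only-sorted",
    dataset="ami",
    split="test",
    model_name="openai/whisper-1",
    use_url=False,
    max_samples=10,   # small smoke test; drop to run the full split
    max_workers=4,
)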
