Commit 61944ea

add AssemblyAI, Rev AI, and ElevenLabs APIs

1 parent: d35005c

File tree

4 files changed: +473 −68 lines

openai/run_api.sh

Lines changed: 79 additions & 0 deletions
```bash
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

export OPENAI_API_KEY="your_api_key"
export ASSEMBLYAI_API_KEY="your_api_key"
export ELEVENLABS_API_KEY="your_api_key"
export REVAI_API_KEY="your_api_key"

MODEL_IDs=(
    "openai/gpt-4o-transcribe"
    "openai/gpt-4o-mini-transcribe"
    "openai/whisper-1"
    "assembly/best"
    "elevenlabs/scribe_v1"
    "revai/machine"
    "revai/fusion"
)

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}
    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path "hf-audio/esb-datasets-test-only-sorted" \
        --dataset "librispeech" \
        --split "test.clean" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path "hf-audio/esb-datasets-test-only-sorted" \
        --dataset "librispeech" \
        --split "test.other" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --model_name ${MODEL_ID}

    python run_eval.py \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --model_name ${MODEL_ID}

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
```
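The script assumes all four provider keys are set before the loop starts; with a placeholder key, every sample fails and burns through run_eval.py's retry budget before falling back to ".". A minimal pre-flight sketch (hypothetical, not part of this commit; the key names match the exports above):

```python
# Hypothetical pre-flight check: fail fast if any provider key exported
# in run_api.sh is missing or still the "your_api_key" placeholder.
import os

REQUIRED_KEYS = [
    "OPENAI_API_KEY",
    "ASSEMBLYAI_API_KEY",
    "ELEVENLABS_API_KEY",
    "REVAI_API_KEY",
]

missing = [key for key in REQUIRED_KEYS if os.getenv(key) in (None, "", "your_api_key")]
if missing:
    raise SystemExit(f"Set these API keys before running: {', '.join(missing)}")
```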

openai/run_eval.py

Lines changed: 163 additions & 68 deletions
```diff
@@ -1,95 +1,188 @@
 import argparse
 import datasets
 import evaluate
-import io
-import json
 import soundfile as sf
 import tempfile
 import time
+import os
+import requests
 from tqdm import tqdm
+from dotenv import load_dotenv
+from io import BytesIO
+import assemblyai as aai
 import openai
-from normalizer import data_utils  # must provide .normalizer() and .write_manifest()
+from elevenlabs.client import ElevenLabs
+from rev_ai import apiclient
+from rev_ai.models import CustomVocabulary, CustomerUrlData
+from normalizer import data_utils
+import concurrent.futures
 
-def transcribe_dataset(
-    dataset_path, dataset, split,
-    model_name="whisper-1",
-):
-    # Load dataset
+load_dotenv()
+
+def transcribe_with_retry(model_name, audio_file_path, sample, max_retries=10):
+    retries = 0
+    while retries <= max_retries:
+        try:
+            if model_name.startswith("assembly/"):
+                aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY")
+                transcriber = aai.Transcriber()
+                config = aai.TranscriptionConfig(
+                    speech_model=model_name.split("/")[1],
+                    language_code="en",
+                )
+                audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
+                if audio_duration < 0.160:
+                    print(f"Skipping audio duration {audio_duration}s")
+                    return "."
+                transcript = transcriber.transcribe(audio_file_path, config=config)
+                if transcript.status == aai.TranscriptStatus.error:
+                    raise Exception(f"AssemblyAI transcription error: {transcript.error}")
+                return transcript.text
+
+            elif model_name.startswith("openai/"):
+                with open(audio_file_path, "rb") as audio_file:
+                    response = openai.Audio.transcribe(
+                        model=model_name.split("/")[1],
+                        file=audio_file,
+                        response_format="text",
+                        language="en",
+                        temperature=0.0,
+                    )
+                return response.strip()
+
+            elif model_name.startswith("elevenlabs/"):
+                client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY"))
+                with open(audio_file_path, "rb") as audio_file:
+                    transcription = client.speech_to_text.convert(
+                        file=audio_file,
+                        model_id=model_name.split("/")[1],
+                        language_code="eng",
+                    )
+                return transcription.text
+
+            elif model_name.startswith("revai/"):
+                access_token = os.getenv("REVAI_API_KEY")
+                client = apiclient.RevAiAPIClient(access_token)
+
+                # Submit job with local file
+                job = client.submit_job_local_file(
+                    transcriber=model_name.split("/")[1],
+                    filename=audio_file_path,
+                    metadata="benchmarking_job",
+                    remove_disfluencies=True,
+                    remove_atmospherics=True,
+                )
+
+                # Poll until the job is done
+                while True:
+                    job_details = client.get_job_details(job.id)
+                    if job_details.status.name in ["IN_PROGRESS", "TRANSCRIBING"]:
+                        time.sleep(0.1)
+                        continue
+                    elif job_details.status.name == "FAILED":
+                        raise Exception("RevAI transcription failed.")
+                    elif job_details.status.name == "TRANSCRIBED":
+                        break
+
+                transcript_object = client.get_transcript_object(job.id)
+
+                # Combine all words from all monologues
+                transcript_text = []
+                for monologue in transcript_object.monologues:
+                    for element in monologue.elements:
+                        transcript_text.append(element.value)
+
+                return "".join(transcript_text) if transcript_text else ""
+
+            else:
+                raise ValueError("Invalid model prefix, must start with 'assembly/', 'openai/', 'elevenlabs/', or 'revai/'")
+
+        except Exception as e:
+            retries += 1
+            if retries > max_retries:
+                return "."
+
+            sf.write(audio_file_path, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV")
+            delay = 1
+            print(f"API Error: {str(e)}. Retrying in {delay}s... (Attempt {retries}/{max_retries})")
+            time.sleep(delay)
+
+
+def transcribe_dataset(dataset_path, dataset, split, model_name, max_samples=None, max_workers=4):
     ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
+    ds = data_utils.prepare_data(ds)
+    if max_samples:
+        ds = ds.take(max_samples)
+
+    results = {"references": [], "predictions": [], "audio_length_s": [], "transcription_time_s": []}
+
+    print(f"Transcribing with model: {model_name}")
 
-    # Track results
-    all_results = {
-        "references": [],
-        "predictions": [],
-        "audio_length_s": [],
-        "transcription_time_s": [],
-    }
-
-    print(f"Transcribing with OpenAI model: {model_name}")
-
-    for i, sample in tqdm(enumerate(ds), total=len(ds), desc="Transcribing"):
-        # Get reference text, use empty string if not present
-        reference = sample.get("text", "").strip()
-
-        # Write temp .wav file
-        with tempfile.NamedTemporaryFile(suffix=".wav") as tmpfile:
+    def process_sample(sample):
+        reference = sample.get("norm_text", "").strip() or " "
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
             sf.write(tmpfile.name, sample["audio"]["array"], sample["audio"]["sampling_rate"], format="WAV")
+            tmp_path = tmpfile.name
+
+        start = time.time()
+        try:
+            transcription = transcribe_with_retry(model_name, tmp_path, sample)
+        except Exception as e:
+            print(f"Failed to transcribe after retries: {e}")
+            os.unlink(tmp_path)
+            return None
+        finally:
+            if os.path.exists(tmp_path):
+                os.unlink(tmp_path)
+            else:
+                print(f"File {tmp_path} does not exist")
+
+        transcription_time = time.time() - start
+        audio_duration = len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
+        transcription = data_utils.normalizer(transcription) or " "
+        return reference, transcription, audio_duration, transcription_time
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_sample = {executor.submit(process_sample, sample): sample for sample in ds}
+        for future in tqdm(concurrent.futures.as_completed(future_to_sample), total=len(future_to_sample), desc="Transcribing"):
+            result = future.result()
+            if result:
+                reference, transcription, audio_duration, transcription_time = result
+                results["predictions"].append(transcription)
+                results["references"].append(reference)
+                results["audio_length_s"].append(audio_duration)
+                results["transcription_time_s"].append(transcription_time)
 
-            start = time.time()
-            response = openai.Audio.transcribe(
-                model=model_name,
-                file=tmpfile,
-                response_format="text"
-            )
-            end = time.time()
-
-            transcription = response.strip()
-            reference = sample["text"]
-            audio_duration = sample["audio_length_s"]
-            transcription_time = end - start
-
-            transcription = data_utils.normalizer(transcription)
-            reference = data_utils.normalizer(reference)
-            # Store
-            all_results["predictions"].append(transcription)
-            all_results["references"].append(reference)
-            all_results["audio_length_s"].append(audio_duration)
-            all_results["transcription_time_s"].append(transcription_time)
-
-    # Save results to manifest
     manifest_path = data_utils.write_manifest(
-        all_results["references"],
-        all_results["predictions"],
-        model_name,
+        results["references"],
+        results["predictions"],
+        model_name.replace("/", "-"),
         dataset_path,
         dataset,
         split,
-        audio_length=all_results["audio_length_s"],
-        transcription_time=all_results["transcription_time_s"],
+        audio_length=results["audio_length_s"],
+        transcription_time=results["transcription_time_s"],
     )
+
     print("Results saved at path:", manifest_path)
 
-    # Evaluate
     wer_metric = evaluate.load("wer")
-    wer = wer_metric.compute(
-        references=all_results["references"],
-        predictions=all_results["predictions"]
-    )
-    wer = round(100 * wer, 2)
-    rtfx = round(
-        sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]),
-        2
-    )
+    wer = wer_metric.compute(references=results["references"], predictions=results["predictions"])
+    wer_percent = round(100 * wer, 2)
+    rtfx = round(sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2)
 
-    print("WER:", wer, "%", "RTFx:", rtfx)
+    print("WER:", wer_percent, "%")
+    print("RTFx:", rtfx)
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Transcribe using OpenAI Whisper API")
 
-    parser.add_argument("--dataset_path", required=True, help="Dataset path or name")
-    parser.add_argument("--dataset", required=True, help="Subset name of the dataset")
-    parser.add_argument("--split", default="test", help="Dataset split")
-    parser.add_argument("--model_name", default="whisper-1", help="OpenAI model name")
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Unified Transcription Script with Concurrency")
+    parser.add_argument("--dataset_path", required=True)
+    parser.add_argument("--dataset", required=True)
+    parser.add_argument("--split", default="test")
+    parser.add_argument("--model_name", required=True, help="Prefix model name with 'assembly/', 'openai/', 'elevenlabs/', or 'revai/'")
+    parser.add_argument("--max_samples", type=int, default=None)
+    parser.add_argument("--max_workers", type=int, default=50, help="Number of concurrent threads")
 
     args = parser.parse_args()
 
@@ -98,4 +191,6 @@ def transcribe_dataset(
         dataset=args.dataset,
         split=args.split,
         model_name=args.model_name,
+        max_samples=args.max_samples,
+        max_workers=args.max_workers,
     )
```
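transcribe_with_retry applies one policy to every provider branch: an initial attempt plus up to max_retries retries with a fixed 1-second delay, falling back to the placeholder "." so the WER computation never receives an empty prediction. A provider-agnostic sketch of that policy (with_retry is illustrative, not part of the commit):

```python
import time

# Illustrative sketch of the retry policy in transcribe_with_retry:
# one initial attempt plus up to `max_retries` retries, a fixed delay
# between attempts, and a placeholder result once the budget is spent.
def with_retry(call, max_retries=10, delay=1, fallback="."):
    for attempt in range(max_retries + 1):
        try:
            return call()
        except Exception as e:
            if attempt == max_retries:
                return fallback  # "." keeps the prediction non-empty for scoring
            print(f"API Error: {e}. Retrying in {delay}s... (Attempt {attempt + 1}/{max_retries})")
            time.sleep(delay)

# Usage sketch: transcribe_once stands in for any of the provider branches.
# text = with_retry(lambda: transcribe_once("sample.wav"))
```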

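The two reported numbers come straight from the script's final lines: WER is the evaluate library's word error rate scaled to a percentage, and RTFx is total audio seconds divided by total transcription seconds (higher means faster than real time). A toy check of the WER half:

```python
# Toy sanity check of the WER computation run_eval.py performs.
import evaluate

wer_metric = evaluate.load("wer")
wer = wer_metric.compute(
    references=["the cat sat on the mat"],
    predictions=["the cat sat on a mat"],
)
print(round(100 * wer, 2))  # 16.67: one substitution over six reference words
```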