Commit 334b8ec

Merge pull request #20 from sanchit-gandhi/short-rtf
[Transformers] Compute RTF for short-form datasets
2 parents 55c3d1d + d58797c commit 334b8ec

8 files changed: +219 -161 lines changed

normalizer/eval_utils.py

Lines changed: 73 additions & 13 deletions
@@ -5,12 +5,13 @@
 import evaluate
 from collections import defaultdict

+
 def read_manifest(manifest_path: str):
     """
     Reads a manifest file (jsonl format) and returns a list of dictionaries containing samples.
     """
     data = []
-    with open(manifest_path, "r", encoding='utf-8') as f:
+    with open(manifest_path, "r", encoding="utf-8") as f:
         for line in f:
             if len(line) > 0:
                 datum = json.loads(line)
@@ -19,7 +20,14 @@ def read_manifest(manifest_path: str):


 def write_manifest(
-    references: list, transcriptions: list, model_id: str, dataset_path: str, dataset_name: str, split: str
+    references: list,
+    transcriptions: list,
+    model_id: str,
+    dataset_path: str,
+    dataset_name: str,
+    split: str,
+    audio_length: list = None,
+    transcription_time: list = None,
 ):
     """
     Writes a manifest file (jsonl format) and returns the path to the file.
@@ -31,6 +39,8 @@ def write_manifest(
         dataset_path: Path to the dataset.
         dataset_name: Name of the dataset.
         split: Dataset split name.
+        audio_length: Length of each audio sample in seconds.
+        transcription_time: Transcription time of each sample in seconds.

     Returns:
         Path to the manifest file.
@@ -41,21 +51,46 @@ def write_manifest(

     if len(references) != len(transcriptions):
         raise ValueError(
-            f"The number of samples in `ground_truths` ({len(references)}) "
+            f"The number of samples in `references` ({len(references)}) "
             f"must match `transcriptions` ({len(transcriptions)})."
         )

-    basedir = './results/'
+    if audio_length is not None and len(audio_length) != len(references):
+        raise ValueError(
+            f"The number of samples in `audio_length` ({len(audio_length)}) "
+            f"must match `references` ({len(references)})."
+        )
+    if transcription_time is not None and len(transcription_time) != len(references):
+        raise ValueError(
+            f"The number of samples in `transcription_time` ({len(transcription_time)}) "
+            f"must match `references` ({len(references)})."
+        )
+
+    audio_length = (
+        audio_length if audio_length is not None else len(references) * [None]
+    )
+    transcription_time = (
+        transcription_time
+        if transcription_time is not None
+        else len(references) * [None]
+    )
+
+    basedir = "./results/"
     if not os.path.exists(basedir):
         os.makedirs(basedir)

-    manifest_path = os.path.join(basedir, f"MODEL_{model_id}_DATASET_{dataset_path}_{dataset_name}_{split}.jsonl")
+    manifest_path = os.path.join(
+        basedir, f"MODEL_{model_id}_DATASET_{dataset_path}_{dataset_name}_{split}.jsonl"
+    )

-    with open(manifest_path, "w", encoding='utf-8') as f:
-        for idx, (text, transcript) in enumerate(zip(references, transcriptions)):
+    with open(manifest_path, "w", encoding="utf-8") as f:
+        for idx, (text, transcript, audio_length, transcription_time) in enumerate(
+            zip(references, transcriptions, audio_length, transcription_time)
+        ):
             datum = {
                 "audio_filepath": f"sample_{idx}",  # dummy value for Speech Data Processor
-                "duration": 0.0,  # dummy value for Speech Data Processor
+                "duration": audio_length,
+                "time": transcription_time,
                 "text": text,
                 "pred_text": transcript,
             }
@@ -106,7 +141,7 @@ def parse_filepath(fp: str):
         dataset_id = ds_fp.replace("DATASET_", "").rstrip(".jsonl")
         return model_id, dataset_id

-    # Compute results per dataset
+    # Compute WER results per dataset, and RTFx over all datasets
     results = {}
     wer_metric = evaluate.load("wer")

@@ -117,34 +152,59 @@ def parse_filepath(fp: str):
         references = [datum["text"] for datum in manifest]
         predictions = [datum["pred_text"] for datum in manifest]

+        time = [datum["time"] for datum in manifest]
+        duration = [datum["duration"] for datum in manifest]
+        compute_rtfx = all(time) and all(duration)
+
         wer = wer_metric.compute(references=references, predictions=predictions)
         wer = round(100 * wer, 2)

+        if compute_rtfx:
+            audio_length = sum(duration)
+            inference_time = sum(time)
+            rtfx = round(sum(duration) / sum(time), 4)
+        else:
+            audio_length = inference_time = rtfx = None
+
         result_key = f"{model_id_of_file} | {dataset_id}"
-        results[result_key] = wer
+        results[result_key] = {"wer": wer, "audio_length": audio_length, "inference_time": inference_time, "rtfx": rtfx}

     print("*" * 80)
     print("Results per dataset:")
     print("*" * 80)

     for k, v in results.items():
-        print(f"{k}: WER = {v:0.2f} %")
+        metrics = f"{k}: WER = {v['wer']:0.2f} %"
+        if v["rtfx"] is not None:
+            metrics += f", RTFx = {v['rtfx']:0.2f}"
+        print(metrics)

     # composite WER should be computed over all datasets and with the same key
     composite_wer = defaultdict(float)
+    composite_audio_length = defaultdict(float)
+    composite_inference_time = defaultdict(float)
     count_entries = defaultdict(int)
     for k, v in results.items():
         key = k.split("|")[0].strip()
-        composite_wer[key] += v
+        composite_wer[key] += v["wer"]
+        if v["rtfx"] is not None:
+            composite_audio_length[key] += v["audio_length"]
+            composite_inference_time[key] += v["inference_time"]
+        else:
+            composite_audio_length[key] = composite_inference_time[key] = None
         count_entries[key] += 1

     # normalize scores & print
     print()
     print("*" * 80)
-    print("Composite WER:")
+    print("Composite Results:")
     print("*" * 80)
     for k, v in composite_wer.items():
         wer = v / count_entries[k]
         print(f"{k}: WER = {wer:0.2f} %")
+    for k in composite_audio_length:
+        if composite_audio_length[k] is not None:
+            rtfx = composite_audio_length[k] / composite_inference_time[k]
+            print(f"{k}: RTFx = {rtfx:0.2f}")
     print("*" * 80)
     return composite_wer, results
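
For reference, RTFx is the inverse real-time factor: total audio duration divided by total transcription time, so higher is faster (an RTFx of 24 means one hour of audio is transcribed in 2.5 minutes). The manifest now carries the per-sample "duration" and "time" fields that feed this ratio. Below is a minimal sketch of how a caller could populate the new write_manifest arguments; load_dataset_samples and transcribe are hypothetical stand-ins for the dataset loading and model inference in run_eval.py, not part of this PR.

# Minimal sketch (not part of this PR): collect per-sample timings and pass
# them to the extended write_manifest. `load_dataset_samples` and `transcribe`
# are hypothetical stand-ins for the dataset loading and model inference.
import time

from normalizer.eval_utils import write_manifest

references, transcriptions, audio_lengths, transcription_times = [], [], [], []
for sample in load_dataset_samples():  # hypothetical: yields {"audio": ..., "text": ...}
    start = time.perf_counter()
    pred_text = transcribe(sample["audio"])  # hypothetical model call
    transcription_times.append(time.perf_counter() - start)

    # audio length in seconds = number of samples / sampling rate
    audio_lengths.append(len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"])
    references.append(sample["text"])
    transcriptions.append(pred_text)

write_manifest(
    references,
    transcriptions,
    model_id="facebook/data2vec-audio-large-960h",
    dataset_path="hf-audio/esb-datasets-test-only-sorted",
    dataset_name="ami",
    split="test",
    audio_length=audio_lengths,
    transcription_time=transcription_times,
)

When score_results later reads the manifest back, `compute_rtfx = all(time) and all(duration)` guards the RTFx computation, so manifests written without timings (fields left as None) still produce WER-only results.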

transformers/calc_rtf.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
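The standalone RTF benchmark script appears superseded by this PR: transcription time is now measured during the evaluation run itself, stored in the manifest via write_manifest, and folded into the RTFx numbers reported by eval_utils.py above.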

transformers/run_data2vec.sh

Lines changed: 9 additions & 9 deletions
@@ -13,7 +13,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="ami" \
         --split="test" \
         --device=0 \
@@ -23,7 +23,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="earnings22" \
         --split="test" \
         --device=0 \
@@ -32,7 +32,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="gigaspeech" \
         --split="test" \
         --device=0 \
@@ -41,7 +41,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="librispeech" \
         --split="test.clean" \
         --device=0 \
@@ -50,7 +50,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="librispeech" \
         --split="test.other" \
         --device=0 \
@@ -59,7 +59,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="spgispeech" \
         --split="test" \
         --device=0 \
@@ -68,7 +68,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="tedlium" \
         --split="test" \
         --device=0 \
@@ -77,7 +77,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="voxpopuli" \
         --split="test" \
         --device=0 \
@@ -86,7 +86,7 @@ do

     python run_eval.py \
         --model_id=${MODEL_ID} \
-        --dataset_path="open-asr-leaderboard/datasets-test-only" \
+        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
         --dataset="common_voice" \
         --split="test" \
         --device=0 \
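
Every hunk in run_data2vec.sh makes the same one-line change: --dataset_path now points at hf-audio/esb-datasets-test-only-sorted instead of open-asr-leaderboard/datasets-test-only, a test-only mirror of the ESB datasets, presumably sorted by duration so that short-form samples of similar length are batched and timed together.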
