Skip to content

Commit 76099f6

Browse files
author
sanchit-gandhi
committed
generalise to multilingual
1 parent aca3f5e commit 76099f6

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

training/eval.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import evaluate
3-
from transformers import AutoModel, AutoProcessor, pipeline
3+
from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
44

55

66
def clap_similarity(clap_model_name_or_path, texts, audios, device):
@@ -24,14 +24,35 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
2424
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
2525
metric = evaluate.load("wer")
2626
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
27+
28+
return_language = None
29+
if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
30+
return_language = True
31+
2732
transcriptions = asr_pipeline(
2833
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
2934
batch_size=int(per_device_eval_batch_size),
35+
return_language=return_language,
3036
)
3137

32-
normalizer = asr_pipeline.tokenizer.normalize
33-
normalized_predictions = [normalizer(t["text"]) for t in transcriptions]
34-
normalized_references = [normalizer(t) for t in prompts]
38+
if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
39+
tokenizer = asr_pipeline.tokenizer
40+
else:
41+
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
42+
43+
english_normalizer = tokenizer.normalize
44+
basic_normalizer = tokenizer.basic_normalize
45+
46+
normalized_predictions = []
47+
normalized_references = []
48+
49+
for pred, ref in zip(transcriptions, prompts):
50+
normalizer = english_normalizer if hasattr(pred, "language") and pred["language"] == "english" else basic_normalizer
51+
norm_ref = normalizer(ref)
52+
if len(norm_ref) > 0:
53+
norm_pred = normalizer(pred)
54+
normalized_predictions.append(norm_pred)
55+
normalized_references.append(norm_pred)
3556

3657
word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
3758

0 commit comments

Comments
 (0)