
Commit 69646e7

Merge branch 'huggingface:main' into nits-improvements
2 parents: 0bab56b + 9232a47

File tree: 1 file changed (+27, -4 lines)

training/eval.py

Lines changed: 27 additions & 4 deletions
@@ -1,6 +1,6 @@
 import torch
 import evaluate
-from transformers import AutoModel, AutoProcessor, pipeline
+from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
 
 
 def clap_similarity(clap_model_name_or_path, texts, audios, device):
@@ -24,13 +24,36 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
+
+    return_language = None
+    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
+        return_language = True
+
     transcriptions = asr_pipeline(
         [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
         batch_size=int(per_device_eval_batch_size),
+        return_language=return_language,
     )
 
-    word_error = 100 * metric.compute(
-        predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
-    )
+    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
+        tokenizer = asr_pipeline.tokenizer
+    else:
+        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
+
+    english_normalizer = tokenizer.normalize
+    basic_normalizer = tokenizer.basic_normalize
+
+    normalized_predictions = []
+    normalized_references = []
+
+    for pred, ref in zip(transcriptions, prompts):
+        normalizer = english_normalizer if pred.get("language", None) == "english" else basic_normalizer
+        norm_ref = normalizer(ref)
+        if len(norm_ref) > 0:
+            norm_pred = normalizer(pred["text"])
+            normalized_predictions.append(norm_pred)
+            normalized_references.append(norm_ref)
+
+    word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
 
     return word_error, [t["text"] for t in transcriptions]
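
For context on the change above: Whisper tokenizers expose two text normalizers, `normalize` (Whisper's English-specific normalization) and `basic_normalize` (language-agnostic cleanup of casing and punctuation). The diff picks between them per transcription, based on the language the ASR pipeline reports, so predictions and references are normalized consistently before WER is computed. A minimal sketch of the two normalizers; the checkpoint and sample sentence are illustrative only:

```python
from transformers import WhisperTokenizer

# Any Whisper checkpoint exposes both normalizers; this choice is illustrative.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

text = "Mr. Smith couldn't come, umm, at 10 o'clock!"

# English-specific normalization (used when the detected language is English).
print(tokenizer.normalize(text))

# Language-agnostic normalization (fallback for other languages).
print(tokenizer.basic_normalize(text))
```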
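
A hedged usage sketch of the updated helper, assuming `training/eval.py` is importable as `training.eval` and that audio is passed as 16 kHz mono float32 arrays; the checkpoint name, batch size, and dummy data below are placeholders, not values taken from this commit:

```python
import numpy as np

from training.eval import wer  # assumes the repo root is on PYTHONPATH

# Placeholder 16 kHz audio; in practice these are the generated speech samples under evaluation.
sampling_rate = 16_000
audios = [np.zeros(sampling_rate, dtype=np.float32) for _ in range(2)]
prompts = ["Hello world.", "Bonjour tout le monde."]

word_error, transcriptions = wer(
    asr_model_name_or_path="openai/whisper-large-v3",  # a Whisper model exercises the new return_language path
    prompts=prompts,
    audios=audios,
    device="cpu",
    per_device_eval_batch_size=2,
    sampling_rate=sampling_rate,
)
print(f"WER: {word_error:.2f}%")
print(transcriptions)
```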
