import torch
import evaluate
-from transformers import AutoModel, AutoProcessor, pipeline
+from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast


def clap_similarity(clap_model_name_or_path, texts, audios, device):
@@ -24,13 +24,39 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
+
+    # Whisper models can report the language they detect; request it so the matching text normalizer is used below.
+    return_language = None
+    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
+        return_language = True
+
    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
+        return_language=return_language,
    )

-    word_error = 100 * metric.compute(
-        predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
-    )
+    # Reuse the pipeline's Whisper tokenizer for its text normalizers; otherwise fall back to whisper-large-v3.
+    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
+        tokenizer = asr_pipeline.tokenizer
+    else:
+        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
+
+    english_normalizer = tokenizer.normalize
+    basic_normalizer = tokenizer.basic_normalize
+
+    normalized_predictions = []
+    normalized_references = []
+
+    for pred, ref in zip(transcriptions, prompts):
+        normalizer = english_normalizer if pred.get("language", None) == "english" else basic_normalizer
+        norm_ref = normalizer(ref)
+        # Only score pairs whose reference is non-empty after normalization.
+        if len(norm_ref) > 0:
+            norm_pred = normalizer(pred["text"])
+            normalized_predictions.append(norm_pred)
+            normalized_references.append(norm_ref)
+
+    word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)

    return word_error, [t["text"] for t in transcriptions]
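
For context, a minimal sketch of what the new normalization step does before WER is computed. This is not part of the commit: it only uses objects that already appear in the diff (WhisperTokenizer.normalize, WhisperTokenizer.basic_normalize, evaluate.load("wer"), and the openai/whisper-large-v3 fallback checkpoint); the reference and prediction strings are invented for illustration.

import evaluate
from transformers import WhisperTokenizer

# Same fallback checkpoint the commit loads when the pipeline's tokenizer is not a Whisper tokenizer.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
metric = evaluate.load("wer")

reference = "Hello, world! It's 5 o'clock."  # invented example text
prediction = "hello world its five o clock"  # invented example text

# English transcripts go through Whisper's English normalizer ...
norm_ref = tokenizer.normalize(reference)
norm_pred = tokenizer.normalize(prediction)

# ... while non-English transcripts would use the basic multilingual normalizer instead:
# norm_ref = tokenizer.basic_normalize(reference)

word_error = 100 * metric.compute(predictions=[norm_pred], references=[norm_ref])
print(f"WER: {word_error:.2f}")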