From acf220accfea62eb86ac57e29a257eded706e71a Mon Sep 17 00:00:00 2001
From: "Yao, Matrix"
Date: Fri, 3 Oct 2025 19:44:46 +0000
Subject: [PATCH] fix asr ut failures

Signed-off-by: Yao, Matrix
---
 src/transformers/generation/logits_process.py           |  2 +-
 .../test_pipelines_automatic_speech_recognition.py      | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 7d81501a783d..ce150f790051 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1915,7 +1915,7 @@ def __init__(self, suppress_tokens, device: str = "cpu"):
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
+        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores
 
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 3235b869a5e4..9b55e52e6fa0 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1104,7 +1104,7 @@ def test_whisper_language(self):
     def test_speculative_decoding_whisper_non_distil(self):
         # Load data:
         dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data
 
         # Load model:
         model_id = "openai/whisper-large-v2"
@@ -1133,8 +1133,8 @@ def test_speculative_decoding_whisper_non_distil(self):
             num_beams=1,
         )
 
-        transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
-        transcription_ass = pipe(sample)["text"]
+        transcription_ass = pipe(sample.clone().detach(), generate_kwargs={"assistant_model": assistant_model})["text"]
+        transcription_non_ass = pipe(sample)["text"]
 
         self.assertEqual(transcription_ass, transcription_non_ass)
         self.assertEqual(
@@ -1422,13 +1422,13 @@ def test_whisper_prompted(self):
         )
 
         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data
 
         # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
         whisper_prompt = "Mr. Quillter."
         prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)
 
-        unprompted_result = pipe(sample.copy())["text"]
+        unprompted_result = pipe(sample.clone().detach())["text"]
         prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]
 
         # fmt: off