2 changes: 1 addition & 1 deletion src/transformers/generation/logits_process.py
@@ -1915,7 +1915,7 @@ def __init__(self, suppress_tokens, device: str = "cpu"):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-    suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
+    suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
    scores = torch.where(suppress_token_mask, -float("inf"), scores)
Contributor Author:
In multi-device setups (e.g. a model spread across two devices): with the current implementation, during assisted decoding the assistant model reuses the main model's SuppressTokensLogitsProcessor, which places suppress_tokens on the same device as the input tensor (device 0). In the Whisper case, the assistant model ingests the main model's encoder_outputs and runs its own decoder, so encoder_outputs may sit on device 1 while the main model's suppress_tokens stay on device 0, which leads to:

RuntimeError: Expected all tensors to be on the same device, but got test_elements is on xpu:0, different from other tensors on xpu:1 (when checking argument in method wrapper_XPU_isin_Tensor_Tensor)

So, given the current design (the assistant model shares the main model's SuppressTokensLogitsProcessor), I move suppress_tokens to scores.device when doing the isin.
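A minimal sketch of the failure mode and the fix, assuming two accelerator devices (the trace above is from XPU; plain torch.isin stands in here for isin_mps_friendly, and the tensor values are made up):

import torch

# The main model pinned suppress_tokens to device 0 when the processor was built.
suppress_tokens = torch.tensor([220, 50257], device="xpu:0")

# The assistant decoder produces scores on device 1 (e.g. because encoder_outputs live there).
scores = torch.randn(1, 51865, device="xpu:1")

vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)

# torch.isin(vocab_tensor, suppress_tokens) would raise the cross-device RuntimeError;
# moving the test elements onto scores.device first keeps every operand together.
mask = torch.isin(vocab_tensor, suppress_tokens.to(scores.device))
scores = torch.where(mask, -float("inf"), scores)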

return scores

10 changes: 5 additions & 5 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1104,7 +1104,7 @@ def test_whisper_language(self):
def test_speculative_decoding_whisper_non_distil(self):
    # Load data:
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
-    sample = dataset[0]["audio"]
+    sample = dataset[0]["audio"].get_all_samples().data

# Load model:
model_id = "openai/whisper-large-v2"
@@ -1133,8 +1133,8 @@ def test_speculative_decoding_whisper_non_distil(self):
    num_beams=1,
)

-    transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
-    transcription_ass = pipe(sample)["text"]
+    transcription_ass = pipe(sample.clone().detach(), generate_kwargs={"assistant_model": assistant_model})["text"]
+    transcription_non_ass = pipe(sample)["text"]
Comment on lines +1136 to +1137
Contributor:
thanks for catching the incorrect inversion here!


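A hedged aside on the copy-semantics change in these lines: since the sample is now a torch.Tensor (per the get_all_samples().data change above) rather than a numpy array, ndarray.copy() is replaced by Tensor.clone().detach(), which likewise yields an independent copy:

import torch

# Stand-in for the decoded audio tensor returned by get_all_samples().data.
sample = torch.randn(16000)

# clone() copies the storage; detach() drops any autograd history, so the
# pipeline can mutate or move the duplicate without touching the original.
duplicate = sample.clone().detach()
assert duplicate.data_ptr() != sample.data_ptr()  # independent memory
assert torch.equal(duplicate, sample)             # identical values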
self.assertEqual(transcription_ass, transcription_non_ass)
self.assertEqual(
@@ -1422,13 +1422,13 @@ def test_whisper_prompted(self):
)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
-    sample = dataset[0]["audio"]
+    sample = dataset[0]["audio"].get_all_samples().data

# prompt the model to misspell "Mr Quilter" as "Mr Quillter"
whisper_prompt = "Mr. Quillter."
prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)

-    unprompted_result = pipe(sample.copy())["text"]
+    unprompted_result = pipe(sample.clone().detach())["text"]
prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]

# fmt: off
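For context on the recurring get_all_samples().data change, a minimal sketch of the new audio access pattern, assuming the torchcodec-backed decoder object that recent datasets releases return for audio columns (not verified against this PR's exact dependency pins):

from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")

# dataset[0]["audio"] is now a decoder object rather than a dict of numpy arrays;
# get_all_samples() decodes the full clip and .data exposes it as a torch.Tensor.
audio = dataset[0]["audio"]
sample = audio.get_all_samples().data
print(sample.shape, sample.dtype)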