
Commit a6b51e7

gante and vasqu authored
[Whisper + beam search] fix usage of beam_indices (#38259)
* tmp

* fix test_tiny_token_timestamp_batch_generation

* better comments

* test

* comments

* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <[email protected]>

---------

Co-authored-by: Anton Vlasjuk <[email protected]>
1 parent 3e960e0 commit a6b51e7

File tree

2 files changed: +28 −19 lines changed


src/transformers/models/whisper/generation_whisper.py

Lines changed: 28 additions & 18 deletions
@@ -231,42 +231,52 @@ def _extract_token_timestamps(
             tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
 
         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
 
         weight_length = None
 
         if "beam_indices" in generate_outputs:
-            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
-            # since the beam search strategy chooses the most probable sequences at the end of the search.
-            # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
-            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
-
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`
 
-            weights = weights[:, :, :weight_length]
+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
+            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
+
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices
 
             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
 
-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
             )
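To see the new gather end to end, here is a minimal standalone sketch of the patched beam-index handling on toy tensors. The variable names mirror the diff; the toy beam_indices values, head count, and lengths are illustrative assumptions, not taken from the commit, and the leading weights dimension conflates batch and beams for a batch of one.

import torch

# Toy setup: 2 candidate sequences, and a prefill that processed 3 prompt tokens.
num_beams, num_heads, num_input_ids = 2, 4, 3

# One row per returned sequence; -1 pads the steps after a beam finished
# (here the second sequence stopped decoding two steps early).
beam_indices = torch.tensor([[0, 1, 1, 0], [1, 0, -1, -1]])

# Real sequence length of the longest sequence; crop beam_indices to it.
weight_length = (beam_indices != -1).sum(-1).max()  # tensor(4)
beam_indices = beam_indices[:, :weight_length]

# The prefill produced cross-attention weights for num_input_ids tokens, all
# coming from the same beam, so unroll the first beam index accordingly.
if num_input_ids > 1:
    weight_length += num_input_ids - 1  # 4 + 2 = 6 gather steps
    beam_indices_first_step_unrolled = (
        torch.ones(beam_indices.shape[0], num_input_ids - 1, dtype=torch.long)
        * beam_indices[:, 0:1]  # broadcasted repeat of each row's first index
    )
    unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
else:
    unrolled_beam_indices = beam_indices

# A leftover -1 (EOS padding) would make index_select raise, so map it to beam 0.
unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)

# Fake cross-attention weights: (batch size * num beams, heads, output length, input length).
# The output length (7) is longer than weight_length (6); iterating only over
# unrolled_beam_indices.shape[1] steps crops the extra decoding step away.
weights = torch.randn(num_beams, num_heads, 7, 10)
weights = torch.stack(
    [
        torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
        for i in range(unrolled_beam_indices.shape[1])
    ],
    dim=2,
)
print(weights.shape)  # torch.Size([2, 4, 6, 10])

The torch.ones(...) * beam_indices[:, 0:1] product is simply a broadcasted repeat of each sequence's first beam index, matching the expression used in the patch.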

tests/models/whisper/test_modeling_whisper.py

Lines changed: 0 additions & 1 deletion
@@ -2155,7 +2155,6 @@ def test_tiny_token_timestamp_batch_generation(self):
 
         # task id and lang id prompts should not have timestamp tokens
         self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
-
         self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
 
     @slow
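For context on what the fixed path looks like from the user side, here is a hedged usage sketch of token-level timestamps under beam search. The checkpoint choice, the silent toy audio, and the generation parameters are illustrative assumptions, not taken from the diff or the test.

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Two seconds of silence stand in for real 16 kHz audio.
audio = torch.zeros(32000).numpy()
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

out = model.generate(
    inputs.input_features,
    num_beams=2,                   # beam search: the path whose beam_indices this commit fixes
    return_token_timestamps=True,  # runs _extract_token_timestamps via the model's alignment heads
)

# Per the test above, every token except the task and language prompt tokens
# gets a timestamp: sequences.shape[-1] - 2 == token_timestamps.shape[-1].
print(out["sequences"].shape, out["token_timestamps"].shape)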
