@@ -231,42 +231,52 @@ def _extract_token_timestamps(
         tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))

         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])

         weight_length = None

         if "beam_indices" in generate_outputs:
-            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
-            # since the beam search strategy chooses the most probable sequences at the end of the search.
-            # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
-            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
-
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`

-            weights = weights[:, :, :weight_length]
+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
+            weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
+
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices

             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)

-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
             )
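And a minimal usage sketch of the path this hunk touches, assuming the public `return_token_timestamps` flag of `WhisperForConditionalGeneration.generate`; the silent `audio_array` waveform is a placeholder you would replace with real audio:

```python
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Placeholder input: one second of 16 kHz silence; load real audio here.
audio_array = torch.zeros(16_000).numpy()

inputs = processor(audio_array, sampling_rate=16_000, return_tensors="pt")
outputs = model.generate(
    inputs.input_features,
    num_beams=2,                   # beam search: exercises the patched branch
    return_token_timestamps=True,  # calls _extract_token_timestamps internally
)
print(outputs.token_timestamps)    # seconds, one value per generated token
```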