@@ -226,6 +226,7 @@ def _forward(
         step_confidence = None

         decoder_mems_list = None
+        xatt_scores_list = None
         for i in range(max_generation_length):

             if i == 0:
@@ -234,14 +235,20 @@ def _forward(
                 i += tgt_len - 1
                 input_ids = tgt[:, -1:]

-            logits, decoder_mems_list, _ = self._one_step_forward(
+            logits, decoder_mems_list, new_xatt_scores_list = self._one_step_forward(
                 input_ids,
                 encoder_hidden_states,
                 encoder_input_mask,
                 decoder_mems_list,
                 i,
                 return_scores=return_beam_scores,
             )
+            if xatt_scores_list is not None:
+                for layer in range(len(xatt_scores_list)):
+                    xatt_scores_list[layer] = torch.cat(
+                        (xatt_scores_list[layer], new_xatt_scores_list[layer]), dim=2)
+            else:
+                xatt_scores_list = new_xatt_scores_list

             if self.temperature is None:  # Greedy decoding
                 next_tokens = torch.argmax(logits[:, -1], dim=-1)
@@ -272,7 +279,7 @@ def _forward(
             samples = list(tgt.view(orig_batch_size, self.n_samples, -1))
             tgt = tgt[:: self.n_samples]

-        return tgt, samples, step_confidence_tensor
+        return tgt, samples, step_confidence_tensor, xatt_scores_list

     def __call__(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
@@ -284,12 +291,12 @@ def __call__(
             if not return_beam_scores:
                 return results
             else:
-                prefixes, scores, tgt = results
+                prefixes, scores, tgt, xatt_scores_list = results
                 prefixes = prefixes.view(-1, self.beam_size, tgt.size(1)).split(1, dim=0)
                 scores = scores.view(-1, self.beam_size).split(1, dim=0)
                 prefixes = [x.squeeze(0) for x in prefixes]  # each item is [beam, seq_len]
                 scores = [x.squeeze(0) for x in scores]  # each item is [beam,]
-                return prefixes, scores, tgt
+                return prefixes, scores, tgt, xatt_scores_list

     def freeze(self) -> None:
         """Freeze weights of embedding, decoder, and classification layers to prevent memory leak."""
@@ -413,7 +420,7 @@ def _forward(
         tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)

         # generate initial buffer of beam_size prefixes-hypotheses
-        log_probs, decoder_mems_list, _ = self._one_step_forward(
+        log_probs, decoder_mems_list, xatt_scores_list = self._one_step_forward(
             tgt, encoder_hidden_states, encoder_input_mask, None, 0
         )
         scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
@@ -434,6 +441,10 @@ def _forward(
         else:
             hidden_size = decoder_mems_list[0].size(2)

+        # repeat xattn scores
+        if xatt_scores_list is not None:
+            xatt_scores_list = [xatt_layer.repeat(self.beam_size, 1, 1, 1) for xatt_layer in xatt_scores_list]
+
         # pad_profile tracks finished hypotheses to generate only <pad> tokens
         # if <eos> or <pad> has been generated
         pad_profile = torch.zeros_like(scores).long()
@@ -449,7 +460,7 @@ def _forward(
             pad_mask = pad_profile.repeat(1, self.beam_size)

             # generate and score candidates for prefixes continuation
-            log_probs, decoder_mems_list, _ = self._one_step_forward(
+            log_probs, decoder_mems_list, next_xatt_scores_list = self._one_step_forward(
                 prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
             scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)
@@ -478,6 +489,19 @@ def _forward(
             prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
             prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

+            # select xatt scores corresponding to chosen hypotheses
+            if next_xatt_scores_list is not None:
+                num_heads = xatt_scores_list[0].shape[1]
+                xatt_indices_i = indices_i.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                    1, 1, num_heads, p_len - 1, src_length) // self.beam_size
+                for layer in range(len(next_xatt_scores_list)):
+                    xatt_layer_score_i = torch.cat((xatt_scores_list[layer], next_xatt_scores_list[layer]), dim=2)
+                    xatt_scores_list[layer] = xatt_layer_score_i.view(
+                        -1, self.beam_size, num_heads, p_len - 1, src_length
+                    ).gather(1, xatt_indices_i).view(
+                        -1, num_heads, p_len - 1, src_length
+                    )
+
             # reshuffle cached decoder memory states to restore the order
             # of hypotheses broken after top-k selection
             mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
@@ -501,13 +525,22 @@ def _forward(
         # select best performing hypotheses in each element of the batch
         len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
         scores = scores / len_penalties
-        best_guesses = (
-            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
-        )
-        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)
+        best_guesses = torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True)
+        tgt_best_guesses = best_guesses.repeat(1, prefixes.size(1)).unsqueeze(1)
+        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, tgt_best_guesses).squeeze(1)
+
+        # select xatt scores for best hypotheses
+        if xatt_scores_list is not None:
+            _, num_heads, tgt_len, src_len = xatt_scores_list[0].shape
+            xatt_best_guesses = best_guesses.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                1, 1, num_heads, tgt_len, src_len)
+            for layer in range(len(xatt_scores_list)):
+                xatt_scores_list[layer] = xatt_scores_list[layer].view(
+                    -1, self.beam_size, num_heads, tgt_len, src_len
+                ).gather(1, xatt_best_guesses).squeeze(1)

         if return_beam_scores:
-            return prefixes, scores * len_penalties, tgt
+            return prefixes, scores * len_penalties, tgt, xatt_scores_list
         else:
             return tgt

@@ -549,7 +582,7 @@ def _forward(
         batch_fusion_states_candidates_list = []

         # generate initial buffer of beam_size prefixes-hypotheses
-        log_probs, decoder_mems_list, _ = self._one_step_forward(
+        log_probs, decoder_mems_list, xatt_scores_list = self._one_step_forward(
             tgt, encoder_hidden_states, encoder_input_mask, None, 0
         )
         # get fusion models scores
@@ -585,6 +618,10 @@ def _forward(
         else:
             hidden_size = decoder_mems_list[0].size(2)

+        # repeat xattn scores
+        if xatt_scores_list is not None:
+            xatt_scores_list = [xatt_layer.repeat(self.beam_size, 1, 1, 1) for xatt_layer in xatt_scores_list]
+
         # pad_profile tracks finished hypotheses to generate only <pad> tokens
         # if <eos> or <pad> has been generated
         pad_profile = torch.zeros_like(scores).long()
@@ -600,7 +637,7 @@ def _forward(
             pad_mask = pad_profile.repeat(1, self.beam_size)

             # generate and score candidates for prefixes continuation
-            log_probs, decoder_mems_list, _ = self._one_step_forward(
+            log_probs, decoder_mems_list, next_xatt_scores_list = self._one_step_forward(
                 prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
             for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
@@ -647,6 +684,19 @@ def _forward(
             prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
             prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

+            # select xatt scores corresponding to chosen hypotheses
+            if next_xatt_scores_list is not None:
+                num_heads = xatt_scores_list[0].shape[1]
+                xatt_indices_i = indices_i.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                    1, 1, num_heads, p_len - 1, src_length) // self.beam_size
+                for layer in range(len(next_xatt_scores_list)):
+                    xatt_layer_score_i = torch.cat((xatt_scores_list[layer], next_xatt_scores_list[layer]), dim=2)
+                    xatt_scores_list[layer] = xatt_layer_score_i.view(
+                        -1, self.beam_size, num_heads, p_len - 1, src_length
+                    ).gather(1, xatt_indices_i).view(
+                        -1, num_heads, p_len - 1, src_length
+                    )
+
             # reshuffle cached decoder memory states to restore the order
             # of hypotheses broken after top-k selection
             mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
@@ -670,13 +720,22 @@ def _forward(
         # select best performing hypotheses in each element of the batch
         len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
         scores = scores / len_penalties
-        best_guesses = (
-            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
-        )
-        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)
+        best_guesses = torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True)
+        tgt_best_guesses = best_guesses.repeat(1, prefixes.size(1)).unsqueeze(1)
+        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, tgt_best_guesses).squeeze(1)
+
+        # select xatt scores for best hypotheses
+        if xatt_scores_list is not None:
+            _, num_heads, tgt_len, src_len = xatt_scores_list[0].shape
+            xatt_best_guesses = best_guesses.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                1, 1, num_heads, tgt_len, src_len)
+            for layer in range(len(xatt_scores_list)):
+                xatt_scores_list[layer] = xatt_scores_list[layer].view(
+                    -1, self.beam_size, num_heads, tgt_len, src_len
+                ).gather(1, xatt_best_guesses).squeeze(1)

         if return_beam_scores:
-            return prefixes, scores * len_penalties, tgt
+            return prefixes, scores * len_penalties, tgt, xatt_scores_list
         else:
             return tgt

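For context, here is a minimal caller-side sketch of how the cross-attention scores returned by the greedy generator might be consumed. It is an illustration only, not part of this change: the `generator` instance and the input tensors are assumed placeholders, and the per-layer shape [batch, num_heads, tgt_len, src_len] is inferred from the `torch.cat(..., dim=2)` accumulation and the reshapes in the diff above.

```python
import torch

# Hypothetical inputs: a constructed greedy generator plus encoder outputs.
# With return_beam_scores left at False, __call__ forwards the 4-tuple
# returned by the greedy _forward unchanged.
tgt, samples, step_confidence_tensor, xatt_scores_list = generator(
    decoder_input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_input_mask=encoder_input_mask,
)

if xatt_scores_list is not None:
    # Stack per-layer tensors -> [num_layers, batch, num_heads, tgt_len, src_len],
    # then average over layers and heads for a rough token-to-source alignment.
    xatt = torch.stack(xatt_scores_list, dim=0)
    alignment = xatt.mean(dim=(0, 2))          # [batch, tgt_len, src_len]
    src_positions = alignment.argmax(dim=-1)   # [batch, tgt_len]
```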