Commit 2de6160

Add cross-attention to output hypotheses
1 parent 661af02 commit 2de6160

File tree

5 files changed: +135 -34 lines changed

nemo/collections/asr/modules/transformer/transformer_generators.py

Lines changed: 77 additions & 18 deletions
@@ -226,6 +226,7 @@ def _forward(
         step_confidence = None

         decoder_mems_list = None
+        xatt_scores_list = None
         for i in range(max_generation_length):

             if i == 0:
@@ -234,14 +235,20 @@ def _forward(
                 i += tgt_len - 1
             input_ids = tgt[:, -1:]

-            logits, decoder_mems_list, _ = self._one_step_forward(
+            logits, decoder_mems_list, new_xatt_scores_list = self._one_step_forward(
                 input_ids,
                 encoder_hidden_states,
                 encoder_input_mask,
                 decoder_mems_list,
                 i,
                 return_scores=return_beam_scores,
             )
+            if xatt_scores_list is not None:
+                for layer in range(len(xatt_scores_list)):
+                    xatt_scores_list[layer] = torch.cat(
+                        (xatt_scores_list[layer], new_xatt_scores_list[layer]), dim=2)
+            else:
+                xatt_scores_list = new_xatt_scores_list

             if self.temperature is None:  # Greedy decoding
                 next_tokens = torch.argmax(logits[:, -1], dim=-1)
@@ -272,7 +279,7 @@ def _forward(
             samples = list(tgt.view(orig_batch_size, self.n_samples, -1))
             tgt = tgt[:: self.n_samples]

-        return tgt, samples, step_confidence_tensor
+        return tgt, samples, step_confidence_tensor, xatt_scores_list

     def __call__(
         self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
@@ -284,12 +291,12 @@ def __call__(
         if not return_beam_scores:
             return results
         else:
-            prefixes, scores, tgt = results
+            prefixes, scores, tgt, xatt_scores_list = results
             prefixes = prefixes.view(-1, self.beam_size, tgt.size(1)).split(1, dim=0)
             scores = scores.view(-1, self.beam_size).split(1, dim=0)
             prefixes = [x.squeeze(0) for x in prefixes]  # each item is [beam, seq_len]
             scores = [x.squeeze(0) for x in scores]  # each item is [beam,]
-            return prefixes, scores, tgt
+            return prefixes, scores, tgt, xatt_scores_list

     def freeze(self) -> None:
         """Freeze weights of embedding, decoder, and classification layers to prevent memory leak."""
@@ -413,7 +420,7 @@ def _forward(
         tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)

         # generate initial buffer of beam_size prefixes-hypotheses
-        log_probs, decoder_mems_list, _ = self._one_step_forward(
+        log_probs, decoder_mems_list, xatt_scores_list = self._one_step_forward(
             tgt, encoder_hidden_states, encoder_input_mask, None, 0
         )
         scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
@@ -434,6 +441,10 @@ def _forward(
         else:
             hidden_size = decoder_mems_list[0].size(2)

+        # repeat xattn scores
+        if xatt_scores_list is not None:
+            xatt_scores_list = [xatt_layer.repeat(self.beam_size, 1, 1, 1) for xatt_layer in xatt_scores_list]
+
         # pad_profile tracks finished hypotheses to generate only <pad> tokens
         # if <eos> or <pad> has been generated
         pad_profile = torch.zeros_like(scores).long()
@@ -449,7 +460,7 @@ def _forward(
             pad_mask = pad_profile.repeat(1, self.beam_size)

             # generate and score candidates for prefixes continuation
-            log_probs, decoder_mems_list, _ = self._one_step_forward(
+            log_probs, decoder_mems_list, next_xatt_scores_list = self._one_step_forward(
                 prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
             )
             scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)
@@ -478,6 +489,19 @@ def _forward(
             prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
             prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

+            # select xatt scores corresponding to chosen hypotheses
+            if next_xatt_scores_list is not None:
+                num_heads = xatt_scores_list[0].shape[1]
+                xatt_indices_i = indices_i.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                    1, 1, num_heads, p_len - 1, src_length) // self.beam_size
+                for layer in range(len(next_xatt_scores_list)):
+                    xatt_layer_score_i = torch.cat((xatt_scores_list[layer], next_xatt_scores_list[layer]), dim=2)
+                    xatt_scores_list[layer] = xatt_layer_score_i.view(
+                        -1, self.beam_size, num_heads, p_len - 1, src_length
+                    ).gather(1, xatt_indices_i).view(
+                        -1, num_heads, p_len - 1, src_length
+                    )
+
             # reshuffle cached decoder memory states to restore the order
             # of hypotheses broken after top-k selection
             mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
@@ -501,13 +525,22 @@ def _forward(
         # select best performing hypotheses in each element of the batch
         len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
         scores = scores / len_penalties
-        best_guesses = (
-            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
-        )
-        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)
+        best_guesses = torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True)
+        tgt_best_guesses = best_guesses.repeat(1, prefixes.size(1)).unsqueeze(1)
+        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, tgt_best_guesses).squeeze(1)
+
+        # select xatt scores for best hypotheses
+        if xatt_scores_list is not None:
+            _, num_heads, tgt_len, src_len = xatt_scores_list[0].shape
+            xatt_best_guesses = best_guesses.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                1, 1, num_heads, tgt_len, src_len)
+            for layer in range(len(xatt_scores_list)):
+                xatt_scores_list[layer] = xatt_scores_list[layer].view(
+                    -1, self.beam_size, num_heads, tgt_len, src_len
+                ).gather(1, xatt_best_guesses).squeeze(1)

         if return_beam_scores:
-            return prefixes, scores * len_penalties, tgt
+            return prefixes, scores * len_penalties, tgt, xatt_scores_list
         else:
             return tgt

@@ -549,7 +582,7 @@ def _forward(
         batch_fusion_states_candidates_list = []

         # generate initial buffer of beam_size prefixes-hypotheses
-        log_probs, decoder_mems_list, _ = self._one_step_forward(
+        log_probs, decoder_mems_list, xatt_scores_list = self._one_step_forward(
             tgt, encoder_hidden_states, encoder_input_mask, None, 0
         )
         # get fusion models scores
@@ -585,6 +618,10 @@ def _forward(
         else:
             hidden_size = decoder_mems_list[0].size(2)

+        # repeat xattn scores
+        if xatt_scores_list is not None:
+            xatt_scores_list = [xatt_layer.repeat(self.beam_size, 1, 1, 1) for xatt_layer in xatt_scores_list]
+
         # pad_profile tracks finished hypotheses to generate only <pad> tokens
         # if <eos> or <pad> has been generated
         pad_profile = torch.zeros_like(scores).long()
@@ -600,7 +637,7 @@ def _forward(
             pad_mask = pad_profile.repeat(1, self.beam_size)

             # generate and score candidates for prefixes continuation
-            log_probs, decoder_mems_list, _ = self._one_step_forward(
+            log_probs, decoder_mems_list, next_xatt_scores_list = self._one_step_forward(
                 prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i
            )
             for fusion_model_idx, fusion_model in enumerate(self.fusion_models):
@@ -647,6 +684,19 @@ def _forward(
             prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
             prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

+            # select xatt scores corresponding to chosen hypotheses
+            if next_xatt_scores_list is not None:
+                num_heads = xatt_scores_list[0].shape[1]
+                xatt_indices_i = indices_i.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                    1, 1, num_heads, p_len - 1, src_length) // self.beam_size
+                for layer in range(len(next_xatt_scores_list)):
+                    xatt_layer_score_i = torch.cat((xatt_scores_list[layer], next_xatt_scores_list[layer]), dim=2)
+                    xatt_scores_list[layer] = xatt_layer_score_i.view(
+                        -1, self.beam_size, num_heads, p_len - 1, src_length
+                    ).gather(1, xatt_indices_i).view(
+                        -1, num_heads, p_len - 1, src_length
+                    )
+
             # reshuffle cached decoder memory states to restore the order
             # of hypotheses broken after top-k selection
             mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
@@ -670,13 +720,22 @@ def _forward(
         # select best performing hypotheses in each element of the batch
         len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
         scores = scores / len_penalties
-        best_guesses = (
-            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
-        )
-        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses).squeeze(1)
+        best_guesses = torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True)
+        tgt_best_guesses = best_guesses.repeat(1, prefixes.size(1)).unsqueeze(1)
+        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, tgt_best_guesses).squeeze(1)
+
+        # select xatt scores for best hypotheses
+        if xatt_scores_list is not None:
+            _, num_heads, tgt_len, src_len = xatt_scores_list[0].shape
+            xatt_best_guesses = best_guesses.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
+                1, 1, num_heads, tgt_len, src_len)
+            for layer in range(len(xatt_scores_list)):
+                xatt_scores_list[layer] = xatt_scores_list[layer].view(
+                    -1, self.beam_size, num_heads, tgt_len, src_len
+                ).gather(1, xatt_best_guesses).squeeze(1)

         if return_beam_scores:
-            return prefixes, scores * len_penalties, tgt
+            return prefixes, scores * len_penalties, tgt, xatt_scores_list
         else:
             return tgt
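The trickiest bookkeeping in this diff is the beam reshuffle: after every top-k step, a surviving hypothesis must inherit the cross-attention history of the parent beam it was expanded from. Below is a minimal, self-contained sketch of that view/gather/view pattern with toy shapes and a hypothetical parents index tensor; it mirrors the logic above (the diff additionally maps candidate indices to parent beams with // self.beam_size) but is not the NeMo code itself.

import torch

# Shapes follow the diff: [batch * beam, heads, tgt_len, src_len].
batch, beam, heads, tgt_len, src_len = 2, 3, 4, 5, 6
xatt = torch.randn(batch * beam, heads, tgt_len, src_len)

# Hypothetical parent indices: parents[b, k] is the beam that the new
# hypothesis k in batch element b descends from (0 <= parents < beam).
parents = torch.randint(0, beam, (batch, beam))

# Broadcast the parent index over heads/tgt/src, gather along the beam
# dimension, then flatten back, analogous to the per-layer update above.
idx = parents.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, heads, tgt_len, src_len)
reordered = (
    xatt.view(batch, beam, heads, tgt_len, src_len)
    .gather(1, idx)
    .view(-1, heads, tgt_len, src_len)
)
# The first hypothesis of batch 0 now carries its parent's attention map.
assert torch.equal(reordered[0], xatt[parents[0, 0]])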

nemo/collections/asr/parts/submodules/multitask_beam_decoding.py

Lines changed: 11 additions & 3 deletions
@@ -33,7 +33,10 @@


 def pack_hypotheses(
-    hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]]
+    hypotheses: List[Hypothesis],
+    beam_hypotheses: torch.Tensor,
+    scores: List[Optional[float]],
+    xatt_scores_list: List[torch.Tensor] = None
 ) -> List[Hypothesis]:

     for idx, hyp in enumerate(hypotheses):  # type: Hypothesis
@@ -49,6 +52,9 @@ def pack_hypotheses(
         if hyp.dec_state is not None:
             hyp.dec_state = _states_to_device(hyp.dec_state)

+        if xatt_scores_list is not None:
+            hyp.xatt_scores = [xatt_layer[idx] for xatt_layer in xatt_scores_list]
+
     return hypotheses


@@ -231,7 +237,7 @@ def forward(
         self.transformer_decoder.eval()
         self.log_softmax_module.eval()

-        topk_hypotheses, beam_scores, best_hypo = self.beam_search(
+        topk_hypotheses, beam_scores, best_hypo, xatt_scores_list = self.beam_search(
             encoder_hidden_states=encoder_hidden_states,
             encoder_input_mask=encoder_input_mask,
             decoder_input_ids=decoder_input_ids,
@@ -251,11 +257,13 @@ def forward(
         else:
             beam_scores = [None for _ in range(len(best_hypo))]
             best_hypo = best_hypo.detach().cpu()
+            if xatt_scores_list is not None:
+                xatt_scores_list = [xatt_layer.detach().cpu() for xatt_layer in xatt_scores_list]
             hypotheses = [
                 Hypothesis(score=0.0, y_sequence=[], timestamp=[]) for _ in range(encoder_hidden_states.shape[0])
             ]
             # Pack results into Hypotheses
-            packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores)
+            packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores, xatt_scores_list)
         self.format_hypotheses(packed_result, decoder_input_ids)

         self.transformer_decoder.train()
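For reference, the attachment in pack_hypotheses amounts to slicing each batched per-layer tensor by hypothesis index. A self-contained sketch, assuming layer tensors batched as [batch, heads, U, T] and using a stand-in Hyp dataclass rather than NeMo's Hypothesis:

from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class Hyp:  # stand-in for NeMo's Hypothesis
    xatt_scores: Optional[List[torch.Tensor]] = None

num_layers, batch, heads, U, T = 2, 3, 4, 5, 6
xatt_scores_list = [torch.randn(batch, heads, U, T) for _ in range(num_layers)]

hyps = [Hyp() for _ in range(batch)]
for idx, hyp in enumerate(hyps):
    # hypothesis idx receives one [heads, U, T] map per decoder layer
    hyp.xatt_scores = [xatt_layer[idx] for xatt_layer in xatt_scores_list]

assert hyps[1].xatt_scores[0].shape == (heads, U, T)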

nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py

Lines changed: 18 additions & 3 deletions
@@ -33,6 +33,7 @@ def pack_hypotheses(
     beam_hypotheses: torch.Tensor,
     scores: List[Optional[float]],
     step_confidence: Optional[torch.Tensor] = None,
+    xatt_scores: Optional[List[torch.Tensor]] = None,
 ) -> List[Hypothesis]:

     for idx, hyp in enumerate(hypotheses):  # type: Hypothesis
@@ -52,6 +53,9 @@ def pack_hypotheses(
         if hyp.dec_state is not None:
             hyp.dec_state = _states_to_device(hyp.dec_state)

+        if xatt_scores is not None:
+            hyp.xatt_scores = [xatt_layer[idx] for xatt_layer in xatt_scores]
+
     return hypotheses


@@ -192,7 +196,7 @@ def forward(
         self.transformer_decoder.eval()
         self.log_softmax_module.eval()

-        best_hypo, topk_hypotheses, step_confidence = self.greedy_search(
+        best_hypo, topk_hypotheses, step_confidence, xatt_scores_list = self.greedy_search(
             encoder_hidden_states=encoder_hidden_states,
             encoder_input_mask=encoder_input_mask,
             decoder_input_ids=decoder_input_ids,
@@ -202,23 +206,32 @@ def forward(
             topk_hypotheses = [x.detach().cpu() for x in topk_hypotheses]  # each item is [beam, seq_len]
             beam_scores = [[None] * self.n_samples for _ in topk_hypotheses]  # each item is [beam,]
             packed_result = []
+            if xatt_scores_list is not None:
+                xatt_scores_list = [
+                    xatt_layer.view(len(topk_hypotheses), -1, *xatt_layer.shape[1:]).detach().cpu()
+                    for xatt_layer in xatt_scores_list]
             for i in range(len(topk_hypotheses)):
                 # Pack results into Hypotheses
                 hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestamp=[]) for _ in range(self.n_samples)]
                 self.format_hypotheses(hypotheses, decoder_input_ids)
+                topk_xatt_scores = None
+                if xatt_scores_list is not None:
+                    topk_xatt_scores = [xatt_layer[i] for xatt_layer in xatt_scores_list]
                 packed_result.append(
                     NBestHypotheses(
-                        pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i]), step_confidence
+                        pack_hypotheses(
+                            hypotheses, topk_hypotheses[i], beam_scores[i], step_confidence, topk_xatt_scores)
                     )
                 )
         else:
             beam_scores = [None for _ in range(len(best_hypo))]
             best_hypo = best_hypo.cpu()
+            xatt_scores_list = [xatt_scores_layer.detach().cpu() for xatt_scores_layer in xatt_scores_list]
             hypotheses = [
                 Hypothesis(score=0.0, y_sequence=[], timestamp=[]) for _ in range(encoder_hidden_states.shape[0])
             ]
             # Pack results into Hypotheses
-            packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores, step_confidence)
+            packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores, step_confidence, xatt_scores_list)
         self.format_hypotheses(packed_result, decoder_input_ids)

         self.transformer_decoder.train()
@@ -256,6 +269,8 @@ def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids:
             if pos < -1:
                 hyp.y_sequence = ids[: pos + 1]
                 hyp.token_confidence = hyp.token_confidence[: pos + 1] if hyp.token_confidence is not None else None
+                if hyp.xatt_scores is not None:
+                    hyp.xatt_scores = [xatt_layer[:, : pos + 1, :] for xatt_layer in hyp.xatt_scores]


 @dataclass
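The trimming added to format_hypotheses keeps the attention maps aligned with the shortened token sequence: when y_sequence is cut at ids[: pos + 1], the decoder-step axis (dim 1 of each [heads, U, T] layer tensor) is cut the same way. A small sketch with made-up shapes and a hypothetical pos offset:

import torch

heads, U, T = 4, 7, 9
xatt_scores = [torch.randn(heads, U, T) for _ in range(2)]  # two decoder layers

pos = -3  # hypothetical: two trailing <eos>/<pad> steps to drop
xatt_scores = [xatt_layer[:, : pos + 1, :] for xatt_layer in xatt_scores]

assert xatt_scores[0].shape == (heads, U + pos + 1, T)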

nemo/collections/asr/parts/utils/rnnt_utils.py

Lines changed: 4 additions & 0 deletions
@@ -87,6 +87,9 @@ class Hypothesis:
         last_token (Optional): A token or batch of tokens which was predicted in the last step.

         last_frame (Optional): Index of the last decoding step hypothesis was updated including blank token prediction.
+
+        xatt_scores (Optional): List of cross-attention scores for each decoder layer. Each element of the list
+            is a Tensor of shape num heads x decoder input len x encoder output len (HxUxT).
     """

     score: float
@@ -108,6 +111,7 @@ class Hypothesis:
     last_token: Optional[torch.Tensor] = None
     token_duration: Optional[torch.Tensor] = None
     last_frame: Optional[int] = None
+    xatt_scores: Optional[List[torch.Tensor]] = None

     @property
     def non_blank_frame_confidence(self) -> List[float]:
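One plausible downstream use of the new field (an assumption for illustration, not something this commit implements) is collapsing the per-layer maps into a single token-to-frame soft alignment, e.g. for timestamp estimation:

import torch

def soft_alignment(xatt_scores):
    """Hypothetical helper: average over layers and heads -> [U, T]."""
    stacked = torch.stack(xatt_scores)  # [num_layers, heads, U, T]
    return stacked.mean(dim=(0, 1))  # [U, T]

# alignment = soft_alignment(hyp.xatt_scores)
# frame_of_token = alignment.argmax(dim=-1)  # strongest encoder frame per token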
