diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py
index 2088b08..e5d4e8c 100644
--- a/fireredasr/models/module/transformer_decoder.py
+++ b/fireredasr/models/module/transformer_decoder.py
@@ -113,6 +113,9 @@ def batch_beam_search(self, encoder_outputs, src_masks,
             if is_finished.sum().item() == N*B:
                 break
 
+        for dec_layer in self.layer_stack:
+            dec_layer.cross_attn.clear_states()
+
         # Length penalty (follow GNMT)
         scores = scores.view(N, B)
         ys = ys.view(N, B, -1)
@@ -177,7 +180,7 @@ def __init__(self, d_model, n_head, dropout):
         self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
 
         self.cross_attn_norm = nn.LayerNorm(d_model)
-        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, is_cross=True)
 
         self.mlp_norm = nn.LayerNorm(d_model)
         self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
@@ -212,7 +215,7 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
 
 
 class DecoderMultiHeadAttention(nn.Module):
-    def __init__(self, d_model, n_head, dropout=0.1):
+    def __init__(self, d_model, n_head, dropout=0.1, is_cross=False):
         super().__init__()
         self.d_model = d_model
         self.n_head = n_head
@@ -226,13 +229,27 @@ def __init__(self, d_model, n_head, dropout=0.1):
                                                   temperature=self.d_k ** 0.5)
         self.fc = nn.Linear(n_head * self.d_k, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.is_cross = is_cross
+        self.kv_proj = None
 
-    def forward(self, q, k, v, mask=None):
+    def clear_states(self):
+        self.kv_proj = None
+
+    def forward(self, q, k, v, mask=None):
         bs = q.size(0)
 
         q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
-        k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
-        v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+        if self.is_cross:
+            # Cross-attention reuses the same k/v projections at every decoding step
+            if self.kv_proj is None:
+                self.kv_proj = (
+                    self.w_ks(k).view(bs, -1, self.n_head, self.d_k),
+                    self.w_vs(v).view(bs, -1, self.n_head, self.d_k),
+                )
+            k, v = self.kv_proj
+        else:
+            k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
+            v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
 
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
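
For context, here is the caching pattern this patch introduces, reduced to a minimal standalone sketch. Because the encoder output is fixed during autoregressive decoding, the cross-attention k/v projections are identical at every step, so they can be computed once per utterance and reused; the beam-search hunk then calls clear_states() on every layer after decoding so a stale cache cannot leak into the next utterance. Only w_ks, w_vs, kv_proj, and clear_states mirror the patch; the class name CachedCrossKV, the project_kv helper, and the demo at the bottom are illustrative, not code from the repo.

import torch
import torch.nn as nn

class CachedCrossKV(nn.Module):
    # Minimal sketch (hypothetical class) of the k/v cache added by this patch.
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.w_ks = nn.Linear(d_model, n_head * self.d_k)
        self.w_vs = nn.Linear(d_model, n_head * self.d_k)
        self.kv_proj = None  # holds (k, v) for the current decoding run

    def clear_states(self):
        # Reset between utterances; a stale cache would silently attend
        # to the previous utterance's encoder output.
        self.kv_proj = None

    def project_kv(self, enc_output):
        # project_kv is an illustrative name; the patch does this inline
        # inside DecoderMultiHeadAttention.forward when is_cross is True.
        if self.kv_proj is None:  # first decoding step only
            bs = enc_output.size(0)
            self.kv_proj = (
                self.w_ks(enc_output).view(bs, -1, self.n_head, self.d_k),
                self.w_vs(enc_output).view(bs, -1, self.n_head, self.d_k),
            )
        return self.kv_proj

attn = CachedCrossKV(d_model=8, n_head=2)
enc = torch.randn(3, 5, 8)        # (batch, enc_len, d_model)
k1, v1 = attn.project_kv(enc)     # step 1: runs the linear projections
k2, v2 = attn.project_kv(enc)     # later steps: returns the cached tensors
assert k1 is k2 and v1 is v2      # same tensors, no recomputation
attn.clear_states()               # required before decoding the next utterance

One caveat worth noting at review time: the cache is keyed only on kv_proj being None, so correctness depends entirely on clear_states() being called after every beam-search run, as the first hunk does.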