 from torch import nn
 from transformers import T5Config
 from transformers.activations import ACT2FN
+from transformers.cache_utils import EncoderDecoderCache
 from transformers.models.t5.modeling_t5 import (
     T5Attention,
     T5DenseActDense,
@@ -154,7 +155,7 @@ def forward(
         mask=None,
         key_value_states=None,
         position_bias=None,
-        past_key_value=None,
+        past_key_values=None,
         layer_head_mask=None,
         query_length=None,
         use_cache=False,
@@ -177,38 +178,38 @@ def forward(
             batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
         ).transpose(1, 2)
 
-        if past_key_value is not None:
-            is_updated = past_key_value.is_updated.get(self.layer_idx)
+        # Check if an encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
+        is_updated = False
+        if isinstance(past_key_values, EncoderDecoderCache):
+            is_updated = past_key_values.is_updated.get(self.layer_idx)
             if is_cross_attention:
                 # after the first generated id, we can subsequently re-use all key/value_states from cache
-                curr_past_key_value = past_key_value.cross_attention_cache
+                curr_past_key_values = past_key_values.cross_attention_cache
             else:
-                curr_past_key_value = past_key_value.self_attention_cache
+                curr_past_key_values = past_key_values.self_attention_cache
+        else:
+            curr_past_key_values = past_key_values
 
         current_states = key_value_states if is_cross_attention else hidden_states
-        if is_cross_attention and past_key_value is not None and is_updated:
+        if is_cross_attention and past_key_values is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_values.layers[self.layer_idx].keys
+            value_states = curr_past_key_values.layers[self.layer_idx].values
         else:
             key_states = self.k(current_states)
             value_states = self.v(current_states)
-            key_states = key_states.view(
-                batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
-            ).transpose(1, 2)
-            value_states = value_states.view(
-                batch_size, -1, self.num_attention_heads_per_partition, self.key_value_proj_dim
-            ).transpose(1, 2)
-
-            if past_key_value is not None:
+            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+            if past_key_values is not None:
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation
                 cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = curr_past_key_value.update(
+                key_states, value_states = curr_past_key_values.update(
                     key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                 )
                 # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
-                if is_cross_attention:
-                    past_key_value.is_updated[self.layer_idx] = True
+                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+                    past_key_values.is_updated[self.layer_idx] = True
 
         # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
         scores = torch.matmul(query_states, key_states.transpose(3, 2))
@@ -235,14 +236,9 @@ def forward(
                 causal_mask = mask[:, :, :, : key_states.shape[-2]]
                 position_bias = position_bias + causal_mask
 
-        if self.pruned_heads:
-            mask = torch.ones(position_bias.shape[1])
-            mask[list(self.pruned_heads)] = 0
-            position_bias_masked = position_bias[:, mask.bool()]
-        else:
-            position_bias_masked = position_bias
-
+        position_bias_masked = position_bias
         scores += position_bias_masked
+
         # (batch_size, n_heads, seq_length, key_length)
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
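
For reference, a minimal sketch of the cache API this patch migrates to, assuming a transformers release where `EncoderDecoderCache` exposes `is_updated` and `DynamicCache` stores per-layer tensors behind `.layers[i].keys` / `.values` (the combination the diff above relies on); the layer index and tensor shapes below are illustrative only:

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# An encoder-decoder model carries two caches: one for decoder self-attention,
# one for encoder-decoder cross-attention.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

layer_idx = 0
batch, n_heads, seq_len, head_dim = 1, 8, 4, 64  # illustrative shapes
key_states = torch.randn(batch, n_heads, seq_len, head_dim)
value_states = torch.randn(batch, n_heads, seq_len, head_dim)

# Self-attention states grow step by step through .update(...), which returns
# the full cached keys/values to attend over.
k, v = past_key_values.self_attention_cache.update(key_states, value_states, layer_idx)

# Cross-attention states are written once; the is_updated flag lets later
# decoding steps read them back via .layers[layer_idx].keys / .values,
# mirroring the branch in the attention forward above.
past_key_values.cross_attention_cache.update(key_states, value_states, layer_idx)
past_key_values.is_updated[layer_idx] = True
cached_keys = past_key_values.cross_attention_cache.layers[layer_idx].keys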