@@ -156,7 +156,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
@@ -175,7 +175,7 @@ def forward(

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids
@@ -238,15 +238,13 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[
         torch.FloatTensor,
-        torch.FloatTensor,
-        Optional[torch.FloatTensor],
-        Optional[torch.FloatTensor],
+        Optional[torch.Tensor],
+        Optional[Cache],
     ]:
         residual = hidden_states

@@ -259,7 +257,6 @@ def forward(
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
-            use_cache=use_cache,
         )
         hidden_states = residual + hidden_states

@@ -272,8 +269,6 @@ def forward(
         if not output_attentions:
             self_attn_weights = None

-        if not use_cache:
-            present_key_value = None
         return hidden_states, self_attn_weights, present_key_value


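With the hunks above, the attention and decoder-layer `forward` methods no longer take a `use_cache` flag; callers opt into caching by passing a `Cache` object as `past_key_value` and receive it back as the third output. A minimal usage sketch, assuming a configured `decoder_layer` instance plus placeholder `hidden_states` and `position_ids` tensors (none of which are defined by this commit):

    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()                  # a Cache object replaces the old use_cache=True flag
    hidden_states, attn_weights, cache = decoder_layer(
        hidden_states,                      # placeholder [batch, seq, hidden] tensor
        attention_mask=None,
        position_ids=position_ids,          # placeholder position ids
        past_key_value=cache,               # passing a Cache enables caching
        output_attentions=False,
    )
    print(cache.get_seq_length())           # number of tokens cached so far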
@@ -317,9 +312,10 @@ def forward(
         input_ids: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[
+            Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -335,7 +331,6 @@ def forward(
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache

         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -359,17 +354,25 @@ def forward(
             inputs_embeds = self.embed_layer(input_ids)
         seq_length = inputs_embeds.shape[1]

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                use_cache = False
-
         past_key_values_length = 0
+        use_legacy_cache = False

-        if use_cache:
+        if past_key_values is not None:
             use_legacy_cache = not isinstance(past_key_values, Cache)
+            # Convert a legacy tuple cache into an equivalent Cache object, kept for backward compatibility.
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+            # Assume every layer caches the same sequence length.
+            past_key_values_length = past_key_values.get_seq_length()
+
+        # With gradient checkpointing during training, caching is usually disabled, so simply drop the cache.
+        if (
+            self.gradient_checkpointing
+            and self.training
+            and isinstance(past_key_values, Cache)
+        ):
+            past_key_values = None
+            past_key_values_length = 0

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -412,7 +415,6 @@ def forward(
                     position_ids,
                     past_key_values,
                     output_attentions,
-                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -421,15 +423,14 @@ def forward(
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
-                    use_cache=use_cache,
                 )

             hidden_states = layer_outputs[0]

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

-            if use_cache:
+            if isinstance(past_key_values, Cache):
                 next_decoder_cache = layer_outputs[2]

         hidden_states = self.norm(hidden_states)
@@ -438,7 +439,7 @@ def forward(
             all_hidden_states += (hidden_states,)

         next_cache = None
-        if use_cache:
+        if isinstance(past_key_values, Cache):
             next_cache = (
                 next_decoder_cache.to_legacy_cache()
                 if use_legacy_cache
@@ -484,12 +485,13 @@ def forward(
         input_ids: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[
+            Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.FloatTensor] = None,
         loss_masks: Optional[torch.FloatTensor] = None,
         mask_y: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -525,7 +527,6 @@ def forward(
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -604,16 +605,9 @@ def prepare_inputs_for_generation(
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                if isinstance(past_key_values, DynamicCache):
-                    past_length = past_key_values.seen_tokens
-                else:
-                    past_length = cache_length
-
-                max_cache_length = past_key_values.get_max_length()
+                past_length = past_key_values.get_seq_length()
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+                past_length = past_key_values[0][0].shape[2]

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -622,26 +616,13 @@ def prepare_inputs_for_generation(
            if attention_mask is not None and attention_mask.shape[1] > (
                input_ids.shape[1] // self.config.input_token_len
            ):
-                input_ids = input_ids[
-                    :,
-                    -(attention_mask.shape[1] - past_length)
-                    * self.config.input_token_len :,
-                ]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
                 input_ids = input_ids[:, past_length * self.config.input_token_len :]
             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.

-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + (input_ids.shape[1] // self.config.input_token_len)
-                > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
@@ -662,7 +643,6 @@ def prepare_inputs_for_generation(
             {
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
                 "revin": revin,
                 "num_samples": num_samples,
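The legacy tuple-of-tuples cache format is still accepted throughout. A short, self-contained sketch of the public DynamicCache helpers that the rewritten forward and prepare_inputs_for_generation rely on (the tensor shapes below are arbitrary, chosen only for illustration):

    import torch
    from transformers.cache_utils import DynamicCache

    # A legacy cache: one (key, value) pair per layer, shaped [batch, heads, seq, head_dim].
    legacy_past = ((torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8)),)

    cache = DynamicCache.from_legacy_cache(legacy_past)  # wrap the tuples in a Cache object
    print(cache.get_seq_length())                        # 5 -- replaces get_usable_length(...)
    legacy_again = cache.to_legacy_cache()               # back to tuples when use_legacy_cache is True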