 from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb,
                                                        repeat_kv, rotate_half)
 
+from swift.utils import get_logger
+
+logger = get_logger()
+
 
 def forward_flashattn(
     self,
@@ -306,8 +310,8 @@ def forward_flashattn_inference(
         ))  # noqa
 
     kv_seq_len = k.shape[1]
-    if past_key_value is not None:
-        past_kv_len = past_key_value[0].shape[2]
+    if past_key_value is not None and len(past_key_value):
+        past_kv_len = past_key_value.seen_tokens
         kv_seq_len += past_kv_len
 
     cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
@@ -316,15 +320,13 @@ def forward_flashattn_inference(
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
 
-    if past_key_value is not None:
-        assert (flash_attn_version >=
-                '2.1.0'), 'past_key_value support requires flash-attn >= 2.1.0'
-        # reuse k, v
-        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
-        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
-
-        past_key_value = (k.transpose(1, 2),
-                          v.transpose(1, 2)) if use_cache else None
+    if use_cache:
+        k, v = past_key_value.update(
+            k.transpose(1, 2), v.transpose(1, 2), layer_idx=self.idx)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+    else:
+        past_key_value = None
 
     if attention_mask is None:
         output = flash_attn_func(
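
Review note: this hunk, together with the seen_tokens change above, migrates the KV cache handling from legacy (key, value) tuples to the transformers Cache object, whose update() both stores the new states and returns the concatenated keys/values. A minimal sketch of that API, assuming transformers >= 4.36 with DynamicCache; the shapes and layer index below are illustrative:

# Sketch of the Cache API this hunk targets (assumes transformers >= 4.36).
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
bsz, n_heads, seq_len, head_dim = 1, 8, 4, 64
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# update() appends the states for this layer and returns the full
# concatenated keys/values, replacing the manual torch.cat on tuples.
k_all, v_all = cache.update(k, v, layer_idx=0)
print(len(cache))              # layers with cached states -> 1
print(cache.get_seq_length())  # tokens cached so far -> 4; the hunk reads
                               # the equivalent seen_tokens attribute
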
@@ -405,12 +407,13 @@ def forward_flashattn_inference_s2_attn(
 
 def patch_llama_forward(model: nn.Module, forward_function) -> None:
     # Compatible with transformers device_map
-    for m in model.model.layers:
+    for idx, m in enumerate(model.model.layers):
         new_forward = MethodType(forward_function, m.self_attn)
         if hasattr(model, '_old_forward'):
             m.self_attn._old_forward = new_forward
         else:
             m.self_attn.forward = new_forward
+        m.self_attn.idx = idx
 
 
 def replace_llama_attn(model: nn.Module, use_flash_attn=True):
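
Review note: patch_llama_forward binds the replacement forward onto each layer's attention module with types.MethodType, and now also stamps the layer index onto the module so the forward above can pass layer_idx=self.idx to the cache. A toy sketch of that binding pattern; the Attn class and new_forward below are hypothetical stand-ins, not part of this file:

# Toy illustration of the MethodType patching pattern used above.
# Attn and new_forward are hypothetical stand-ins for the attention
# module and its replacement forward.
from types import MethodType


class Attn:
    def forward(self, x):
        return x


def new_forward(self, x):
    # self resolves to the instance the function was bound to
    return x * 2


attn = Attn()
attn.forward = MethodType(new_forward, attn)  # patch only this instance
attn.idx = 0                                  # per-layer index, as in the diff
print(attn.forward(3))  # -> 6
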
@@ -425,4 +428,7 @@ def replace_llama_attn(model: nn.Module, use_flash_attn=True):
             _prepare_decoder_attention_mask)
         patch_llama_forward(model, forward_flashattn_inference_s2_attn)
     else:
+        logger.warn(
+            'The source code of LongLoRA without flash attention '
+            'may have some problems; please use it carefully.')
         patch_llama_forward(model, forward_noflashattn)
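
Review note: for completeness, a hedged usage sketch of the entry point touched in this hunk; replace_llama_attn is defined in this file, while the model id is only illustrative:

# Usage sketch: patch a Llama model in place with the LongLoRA attention
# forwards defined in this file (model id is illustrative).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
replace_llama_attn(model, use_flash_attn=True)
# With use_flash_attn=False the warning added above is emitted and the
# non-flash forward is patched in instead.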