@@ -378,9 +378,23 @@ def forward(
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
 
         batch_size, q_length, _, _ = query_layer.shape
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            # - key: [batch_size, kv_length, self.num_heads, head_dim]
+            # - value: [batch_size, kv_length, self.num_heads, head_dim]
+            key_layer = paddle.concat((past_key, key_layer), axis=1)
+            value_layer = paddle.concat((past_value, value_layer), axis=1)
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
         version = paddle.version.full_version
         version_check = True
-        if version != "0.0.0" and version <= "2.5.2":
+        if self.config.use_flash_attention and version != "0.0.0" and version <= "2.5.2":
            logger.warning(
                "PaddlePaddle version 2.5.3 or higher is required, please upgrade your PaddlePaddle to 2.5.3 or other higher version."
            )
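For reference, a minimal standalone sketch of the cache update that this hunk hoists out of the attention branches, assuming the [batch_size, kv_length, num_heads, head_dim] layout noted in the comments above (tensor sizes below are made up for illustration, not taken from the PR):

import paddle

batch_size, num_heads, head_dim = 2, 4, 8
past_key = paddle.randn([batch_size, 16, num_heads, head_dim])  # 16 cached tokens
new_key = paddle.randn([batch_size, 1, num_heads, head_dim])    # current token
# The cache grows along axis=1, i.e. the kv_length dimension.
key_layer = paddle.concat((past_key, new_key), axis=1)
print(key_layer.shape)  # [2, 17, 4, 8]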
@@ -397,46 +411,19 @@ def forward(
                 key_states,
                 value_states,
                 attn_mask=attention_mask,
+                dropout_p=self.config.attention_dropout,
+                training=self.training,
                 is_causal=False,
             )
             attn_weights = None
             # [batch_size, seq_len, num_heads, head_dim] => [batch_size, seq_len, hidden_size]
             attn_output = attn_output.reshape([attn_output.shape[0], attn_output.shape[1], -1])
             output_tensor = self.dense(attn_output)
 
-            query_layer = query_layer.transpose([0, 2, 1, 3])
-            key_layer = key_layer.transpose([0, 2, 3, 1])
-            value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache:
-                present = (key_layer, value_layer)
-            else:
-                present = None
         else:
-
             query_layer = query_layer.transpose([0, 2, 1, 3])
             key_layer = key_layer.transpose([0, 2, 3, 1])
             value_layer = value_layer.transpose([0, 2, 1, 3])
-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension:
-                # - key: [batch_size, self.num_heads, head_dim, kv_length]
-                # - value: [batch_size, self.num_heads, kv_length, head_dim]
-                key_layer = paddle.concat((past_key, key_layer), axis=3)
-                value_layer = paddle.concat((past_value, value_layer), axis=2)
-
-            if use_cache is True:
-                present = (key_layer, value_layer)
-            else:
-                present = None
-
             _, _, _, kv_length = key_layer.shape
 
             query_layer = query_layer.reshape([batch_size * self.num_heads, q_length, self.head_dim])
@@ -449,7 +436,6 @@ def forward(
             attention_scores = baddbmm(
                 alibi, batch1=query_layer, batch2=key_layer, beta=self.beta, alpha=self.inv_norm_factor
             )
-
             # change view to [batch_size, num_heads, q_length, kv_length]
             # attention_scores = matmul_result.reshape([batch_size, self.num_heads, q_length, kv_length])
 
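The baddbmm call kept above computes beta * alibi + alpha * (query @ key) over the folded (batch_size * num_heads) dimension. A rough standalone equivalent using paddle.bmm rather than the file's baddbmm helper, with made-up shapes:

import paddle

bnh, q_len, kv_len, head_dim = 8, 5, 5, 16        # bnh = batch_size * num_heads
alibi = paddle.zeros([bnh, 1, kv_len])            # broadcasts over the query dimension
query = paddle.randn([bnh, q_len, head_dim])
key = paddle.randn([bnh, head_dim, kv_len])
beta, alpha = 1.0, 1.0 / (head_dim ** 0.5)        # alpha stands in for self.inv_norm_factor
attention_scores = beta * alibi + alpha * paddle.bmm(query, key)
print(attention_scores.shape)  # [8, 5, 5]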
@@ -949,14 +935,13 @@ def forward(
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[3]
+            past_key_values_length = past_key_values[0][0].shape[1]
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if attention_mask is None:
             attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
         elif attention_mask.dtype != paddle.bool:
             attention_mask = paddle.cast(attention_mask, "bool")
-
         if len(attention_mask.shape) > 2:
             _attention_mask = paddle.ones([batch_size, seq_length_with_past], dtype="bool")
             alibi = build_alibi_tensor(_attention_mask, self.config.n_head, dtype=hidden_states.dtype)
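A small sketch of why the cached-length lookup in this hunk moves from shape[3] to shape[1]: with the cache now stored as [batch_size, kv_length, num_heads, head_dim], the cached sequence length sits in dimension 1 (sizes below are illustrative only):

import paddle

past_key = paddle.randn([2, 16, 4, 8])            # [batch_size, kv_length, num_heads, head_dim]
past_key_values = [(past_key, past_key.clone())]  # one (key, value) pair per layer
past_key_values_length = past_key_values[0][0].shape[1]
print(past_key_values_length)  # 16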