 from paddle.autograd import PyLayer
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-from paddle.distributed.fleet.utils import recompute
+
+from paddlenlp.transformers.refined_recompute import (
+    RRColumnParallelLinear,
+    RRColumnSequenceParallelLinear,
+    RRRowParallelLinear,
+    RRRowSequenceParallelLinear,
+    create_skip_config_for_refined_recompute,
+    recompute,
+)

 try:
     from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -216,6 +224,7 @@ def scaled_dot_product_attention(
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
+    skip_recompute=False,
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
@@ -233,6 +242,7 @@ def scaled_dot_product_attention(
             sequence_parallel,
             reshard_layer,
             npu_is_casual,
+            skip_recompute=skip_recompute,
         )

     # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -605,10 +615,24 @@ def __init__(self, config):
         if config.sequence_parallel:
             ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
             RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if config.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
         else:
             ColumnParallelLinear = linear_utils.ColumnParallelLinear
             RowParallelLinear = linear_utils.RowParallelLinear

+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if config.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear
+
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
                 self.gate_up_fused_proj = ColumnParallelLinear(
@@ -719,9 +743,22 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         if config.sequence_parallel:
             ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
             RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("attention_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if config.skip_recompute_ops.get("attention_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
         else:
             ColumnParallelLinear = linear_utils.ColumnParallelLinear
             RowParallelLinear = linear_utils.RowParallelLinear
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("attention_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if config.skip_recompute_ops.get("attention_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear

         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -821,6 +858,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

         self.attn_func = scaled_dot_product_attention

+        # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+        if (
+            config.recompute
+            and not config.recompute_use_reentrant
+            and config.skip_recompute_ops.get("flash_attn", False)
+        ):
+            self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)
+
     def _init_rope(self):
         if (
             hasattr(self.config, "rope_scaling")
@@ -1471,7 +1516,12 @@ def __init__(self, config: LlamaConfig):
         )

         self.layers = nn.LayerList(
-            [LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)]
+            [
+                LlamaDecoderLayer(
+                    create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
+                )
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.norm = LlamaRMSNorm(config)

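For context on how the pieces above fit together: the RR* classes replace the standard linear layers only when recompute is enabled in non-reentrant mode, and only for the ops flagged in the per-layer `skip_recompute_ops` dict produced by `create_skip_config_for_refined_recompute`. Below is a minimal, self-contained sketch of that selection pattern; `SimpleConfig` and `pick_linear_classes` are hypothetical stand-ins used for illustration, not part of PaddleNLP.

# Illustrative sketch only: SimpleConfig and pick_linear_classes are hypothetical
# stand-ins that mirror the selection pattern added in this diff.
from dataclasses import dataclass, field


@dataclass
class SimpleConfig:
    recompute: bool = True
    recompute_use_reentrant: bool = False  # refined recompute requires False
    # per-layer dict, e.g. as produced by create_skip_config_for_refined_recompute(i, config)
    skip_recompute_ops: dict = field(default_factory=dict)


def pick_linear_classes(config, default_column, default_row, rr_column, rr_row, prefix):
    """Swap in the refined-recompute variants only when non-reentrant recompute
    is enabled and the corresponding op is flagged in skip_recompute_ops."""
    column, row = default_column, default_row
    if config.recompute and not config.recompute_use_reentrant:
        if config.skip_recompute_ops.get(f"{prefix}_column_ln", False):
            column = rr_column
        if config.skip_recompute_ops.get(f"{prefix}_row_ln", False):
            row = rr_row
    return column, row


# Example: an MLP whose column projection should skip recompute.
cfg = SimpleConfig(skip_recompute_ops={"mlp_column_ln": True})
print(pick_linear_classes(cfg, "ColumnParallelLinear", "RowParallelLinear",
                          "RRColumnParallelLinear", "RRRowParallelLinear", "mlp"))
# -> ('RRColumnParallelLinear', 'RowParallelLinear')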