
Commit b1466d7

[NEW Feature] Add hook-based refined_recompute support (#9396)

* Implement refined_recompute
* Update the rr implementation; add unit tests for pp and non-pp
* update llama and support refined recompute
* update rr
* update
* update create_skip_config_for_refined_recompute config.num_hidden_layers
* update llama pp recompute
* refined recompute only supports recompute_use_reentrant=False
* LOD_TENSOR
* typo
* rr: support the Qwen model
* support RRColumnParallelLinear & RRRowParallelLinear
* fix
* update llm test
* fix
* update test_refined_recompute
1 parent f5ca96e commit b1466d7

File tree

14 files changed, +1585 −15 lines


llm/run_finetune.py

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,7 @@
     register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
+from paddlenlp.transformers.refined_recompute import update_refined_recompute
 from paddlenlp.trl import SFTTrainer
 from paddlenlp.trl.llm_utils import (
     ZeroPaddingIterDatasetCallback,
@@ -146,6 +147,10 @@ def main():
     )
 
     LlmMetaConfig.set_llm_config(model_config, training_args)
+    model_config.refined_recompute = update_refined_recompute(
+        training_args.refined_recompute,
+        model_args.lora,
+    )
     model_config.use_fast_layer_norm = model_args.use_fast_layer_norm
 
     # Config for model using dropout, such as GPT.
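For orientation, a minimal sketch of how this wiring is meant to be used from a script. Only `update_refined_recompute`, the `refined_recompute`/`recompute_use_reentrant` config fields, and the assignment pattern come from this commit; the checkpoint name and the "op:num_layers" string format are illustrative assumptions.

# Hypothetical usage sketch -- the string format and model name are assumptions.
from paddlenlp.transformers import AutoConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute

config = AutoConfig.from_pretrained("facebook/llama-7b")  # illustrative checkpoint
config.recompute = True
config.recompute_use_reentrant = False  # refined_recompute needs the non-reentrant recompute path
# Assumed format: comma-separated op names, optionally with a layer count.
config.refined_recompute = update_refined_recompute("flash_attn:-1,attention_column_ln:-1")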

llm/run_pretrain.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,7 @@
     register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
+from paddlenlp.transformers.refined_recompute import update_refined_recompute
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.tools import get_env_device
@@ -413,6 +414,9 @@ def main():
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     # set all llm config
     LlmMetaConfig.set_llm_config(config, training_args)
+    config.refined_recompute = update_refined_recompute(
+        training_args.refined_recompute,
+    )
     config.use_fast_layer_norm = model_args.use_fast_layer_norm
 
     config.seq_length = data_args.max_seq_length

paddlenlp/transformers/configuration_utils.py

Lines changed: 8 additions & 0 deletions
@@ -268,6 +268,14 @@ class LlmMetaConfig:
             "Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
         ),
         ("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
+        # refined_recompute attributes
+        (
+            "refined_recompute",
+            str,
+            "",
+            "refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
+        ),
+        ("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
     ]
 
     @classmethod
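The two new entries follow LlmMetaConfig's existing (name, type, default, help) tuple convention: `refined_recompute` carries the user-facing string from the training arguments, and `skip_recompute_ops` is filled in later with the parsed per-op decisions. A hypothetical sketch of what the parsed mapping might look like; the values and their semantics are assumptions, and the actual parsing lives in paddlenlp/transformers/refined_recompute.py, which is not shown in this excerpt.

# Illustrative only -- assumed parsed form of skip_recompute_ops, keyed by the
# op names listed in the help string above.
skip_recompute_ops = {
    "mlp_row_ln": 0,           # assumed: 0 -> never skip recompute for this op
    "mlp_column_ln": 0,
    "attention_row_ln": 2,     # assumed: skip in a limited number of layers
    "attention_column_ln": 2,
    "flash_attn": -1,          # assumed: skip in every layer
}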

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 11 additions & 3 deletions
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
 except:
     flash_attention = None
 
+from paddlenlp.transformers.refined_recompute import no_recompute
 from paddlenlp.transformers.ring_flash_attention import RingFlashAttention
 
 
@@ -174,6 +175,7 @@ def fusion_flash_attention(
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
+    skip_recompute=False,
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@ def fusion_flash_attention(
                 attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
 
             if hasattr(F, "flashmask_attention"):
-                attn_output = F.flashmask_attention(
+                attn_output = no_recompute(
+                    F.flashmask_attention,
                     query_states,
                     key_states,
                     value_states,
                     startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
                     causal=True,
+                    enable=skip_recompute,
                 )
             else:
-                attn_output = F.flash_attention_with_sparse_mask(
+                attn_output = no_recompute(
+                    F.flash_attention_with_sparse_mask,
                     query_states,
                     key_states,
                     value_states,
                     attn_mask_start_row_indices=attn_mask_startend_row_indices,
                     is_causal=True,
+                    enable=skip_recompute,
                 )
         else:
-            attn_output = F.scaled_dot_product_attention(
+            attn_output = no_recompute(
+                F.scaled_dot_product_attention,
                 query_states,
                 key_states,
                 value_states,
                 attn_mask=attention_mask,
                 is_causal=query_states.shape[1] != 1,
+                enable=skip_recompute,
             )
         attn_weights = None
 
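All three attention paths now route through `no_recompute(fn, *args, enable=..., **kwargs)`. Judging from the call sites, it behaves like `fn(*args, **kwargs)` when `enable=False`, and when `enable=True` it keeps the op's activations so they are not recomputed in the backward pass of the surrounding recompute block. A minimal sketch of that calling convention (shapes and dtype are illustrative; flash-style attention needs a GPU build of Paddle):

import paddle
import paddle.nn.functional as F

from paddlenlp.transformers.refined_recompute import no_recompute

# [batch, seq_len, num_heads, head_dim] -- illustrative shapes
q = paddle.randn([1, 128, 8, 64], dtype="float16")
k = paddle.randn([1, 128, 8, 64], dtype="float16")
v = paddle.randn([1, 128, 8, 64], dtype="float16")

out = no_recompute(
    F.scaled_dot_product_attention,
    q,
    k,
    v,
    attn_mask=None,
    is_causal=True,
    enable=True,  # assumed: exclude this op from recomputation
)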

paddlenlp/transformers/llama/modeling.py

Lines changed: 52 additions & 2 deletions
@@ -29,7 +29,15 @@
 from paddle.autograd import PyLayer
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-from paddle.distributed.fleet.utils import recompute
+
+from paddlenlp.transformers.refined_recompute import (
+    RRColumnParallelLinear,
+    RRColumnSequenceParallelLinear,
+    RRRowParallelLinear,
+    RRRowSequenceParallelLinear,
+    create_skip_config_for_refined_recompute,
+    recompute,
+)
 
 try:
     from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -216,6 +224,7 @@ def scaled_dot_product_attention(
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
+    skip_recompute=False,
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
@@ -233,6 +242,7 @@ def scaled_dot_product_attention(
             sequence_parallel,
             reshard_layer,
             npu_is_casual,
+            skip_recompute=skip_recompute,
         )
 
     # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -605,10 +615,24 @@ def __init__(self, config):
         if config.sequence_parallel:
             ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
             RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if config.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
         else:
             ColumnParallelLinear = linear_utils.ColumnParallelLinear
             RowParallelLinear = linear_utils.RowParallelLinear
 
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if config.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear
+
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
                 self.gate_up_fused_proj = ColumnParallelLinear(
@@ -719,9 +743,22 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         if config.sequence_parallel:
             ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
             RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("attention_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if config.skip_recompute_ops.get("attention_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
         else:
             ColumnParallelLinear = linear_utils.ColumnParallelLinear
             RowParallelLinear = linear_utils.RowParallelLinear
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("attention_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if config.skip_recompute_ops.get("attention_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear
 
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
@@ -821,6 +858,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
         self.attn_func = scaled_dot_product_attention
 
+        # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+        if (
+            config.recompute
+            and not config.recompute_use_reentrant
+            and config.skip_recompute_ops.get("flash_attn", False)
+        ):
+            self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)
+
     def _init_rope(self):
         if (
             hasattr(self.config, "rope_scaling")
@@ -1471,7 +1516,12 @@ def __init__(self, config: LlamaConfig):
         )
 
         self.layers = nn.LayerList(
-            [LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)]
+            [
+                LlamaDecoderLayer(
+                    create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
+                )
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.norm = LlamaRMSNorm(config)
 
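Every decoder layer now receives a per-layer config built by `create_skip_config_for_refined_recompute(i, config)`. Judging from the call sites, it resolves the global refined-recompute setting into per-layer `skip_recompute_ops` flags, which the attention/MLP constructors above consult to decide whether to swap in the RR* linears and whether flash attention skips recomputation. A hedged sketch of the call pattern (building the config by hand like this is an assumption; in training it comes from LlmMetaConfig and update_refined_recompute):

from paddlenlp.transformers import LlamaConfig
from paddlenlp.transformers.refined_recompute import create_skip_config_for_refined_recompute

# Assumed hand-built config, for illustration only.
config = LlamaConfig(num_hidden_layers=4, recompute=True, recompute_use_reentrant=False)
per_layer_configs = [
    create_skip_config_for_refined_recompute(i, config) for i in range(config.num_hidden_layers)
]
# Each per-layer config is what LlamaDecoderLayer consults via config.skip_recompute_ops.get(...).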

paddlenlp/transformers/llama/modeling_pp.py

Lines changed: 9 additions & 2 deletions
@@ -22,9 +22,12 @@
     PipelineLayer,
     SharedLayerDesc,
 )
-from paddle.distributed.fleet.utils import recompute
 
 from paddlenlp.transformers.model_utils import PipelinePretrainedModel
+from paddlenlp.transformers.refined_recompute import (
+    create_skip_config_for_refined_recompute,
+    recompute,
+)
 from paddlenlp.utils.tools import get_env_device
 
 from .modeling import (
@@ -371,7 +374,11 @@ def get_hcg():
 
         for i in range(config.num_hidden_layers):
             self.add_sequential_layer(
-                LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
+                LayerDesc(
+                    LlamaDecoderLayerPipe,
+                    config=create_skip_config_for_refined_recompute(i, config),
+                    layerwise_recompute=i not in self.no_recompute_layers,
+                ),
                 f"llama.layers.{i}",
             )
         self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
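`LayerDesc` only records the class and its keyword arguments; `PipelineLayer` instantiates the stage later. Computing `create_skip_config_for_refined_recompute(i, config)` eagerly inside the loop therefore bakes each layer's skip decisions into its description. A minimal sketch of the pattern (the layer index is arbitrary, and `LlamaDecoderLayerPipe` / `config` are the names from the hunk above, used illustratively):

from paddle.distributed.fleet.meta_parallel import LayerDesc

# Illustrative: the per-layer config is evaluated now, even though the
# LlamaDecoderLayerPipe itself is only constructed when the pipeline is built.
desc = LayerDesc(
    LlamaDecoderLayerPipe,
    config=create_skip_config_for_refined_recompute(3, config),
    layerwise_recompute=True,
)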

paddlenlp/transformers/qwen/modeling.py

Lines changed: 45 additions & 3 deletions
@@ -24,9 +24,18 @@
 from paddle import Tensor, nn
 from paddle.distributed import fleet
 from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker
-from paddle.distributed.fleet.utils import recompute
 from paddle.utils import try_import
 
+from paddlenlp.transformers.refined_recompute import (
+    RRColumnParallelLinear,
+    RRColumnSequenceParallelLinear,
+    RRRowParallelLinear,
+    RRRowSequenceParallelLinear,
+    create_skip_config_for_refined_recompute,
+    no_recompute,
+    recompute,
+)
+
 try:
     from paddle.incubate.nn.functional import swiglu
 except ImportError:
@@ -154,9 +163,22 @@ def __init__(self, config):
         if config.sequence_parallel:
             ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
             RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("attention_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if config.skip_recompute_ops.get("attention_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
         else:
             ColumnParallelLinear = linear_utils.ColumnParallelLinear
             RowParallelLinear = linear_utils.RowParallelLinear
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("attention_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if config.skip_recompute_ops.get("attention_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear
 
         if config.tensor_parallel_degree > 1:
             if config.num_attention_heads % config.tensor_parallel_degree != 0:
@@ -227,12 +249,19 @@ def _attn(self, query, key, value, attention_mask=None):
                 return_softmax=self.config.attn_dropout_prob > 0.0,
             )
         else:
-            attn_output = F.scaled_dot_product_attention(
+            skip_recompute = (
+                self.config.recompute
+                and not self.config.recompute_use_reentrant
+                and self.config.skip_recompute_ops.get("flash_attn", False)
+            )
+            attn_output = no_recompute(
+                F.scaled_dot_product_attention,
                 query,
                 key,
                 value,
                 attn_mask=attention_mask,
                 is_causal=attention_mask is None,
+                enable=skip_recompute,
             )
             attn_weights = None
 
@@ -388,9 +417,22 @@ def __init__(self, config):
         if config.sequence_parallel:
             ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
             RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if config.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
         else:
             ColumnParallelLinear = linear_utils.ColumnParallelLinear
             RowParallelLinear = linear_utils.RowParallelLinear
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if config.recompute and not config.recompute_use_reentrant:
+                if config.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if config.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear
 
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_ffn:
@@ -684,7 +726,7 @@ def __init__(self, config):
         self.h = nn.LayerList(
             [
                 QWenBlock(
-                    config,
+                    create_skip_config_for_refined_recompute(i, config),
                 )
                 for i in range(config.num_hidden_layers)
             ]
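The same four-way choice between the stock parallel linears and their RR* drop-in replacements now appears in the Llama and QWen attention and MLP blocks. An illustrative helper capturing the repeated pattern (not part of the commit; it only reuses the `linear_utils` and RR class names imported in the diffs above):

from paddlenlp.transformers import linear_utils
from paddlenlp.transformers.refined_recompute import (
    RRColumnParallelLinear,
    RRColumnSequenceParallelLinear,
    RRRowParallelLinear,
    RRRowSequenceParallelLinear,
)


def select_parallel_linears(config, column_key, row_key):
    """Pick (column, row) parallel-linear classes for one block.

    column_key / row_key are skip_recompute_ops names such as
    "mlp_column_ln" / "mlp_row_ln". Illustrative refactor only.
    """
    if config.sequence_parallel:
        column_cls = linear_utils.ColumnSequenceParallelLinear
        row_cls = linear_utils.RowSequenceParallelLinear
        rr_column_cls, rr_row_cls = RRColumnSequenceParallelLinear, RRRowSequenceParallelLinear
    else:
        column_cls = linear_utils.ColumnParallelLinear
        row_cls = linear_utils.RowParallelLinear
        rr_column_cls, rr_row_cls = RRColumnParallelLinear, RRRowParallelLinear
    # refined_recompute only takes effect on the non-reentrant recompute path
    if config.recompute and not config.recompute_use_reentrant:
        if config.skip_recompute_ops.get(column_key, False):
            column_cls = rr_column_cls
        if config.skip_recompute_ops.get(row_key, False):
            row_cls = rr_row_cls
    return column_cls, row_cls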

paddlenlp/transformers/qwen/modeling_pp.py

Lines changed: 4 additions & 1 deletion
@@ -18,6 +18,9 @@
 from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
 
 from paddlenlp.transformers.model_utils import PipelinePretrainedModel
+from paddlenlp.transformers.refined_recompute import (
+    create_skip_config_for_refined_recompute,
+)
 
 from .modeling import (
     QWenBlock,
@@ -170,7 +173,7 @@ def get_hcg():
         self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen")
         for i in range(config.num_hidden_layers):
             self.add_sequential_layer(
-                LayerDesc(QWenBlockPipe, config=config),
+                LayerDesc(QWenBlockPipe, config=create_skip_config_for_refined_recompute(i, config)),
                 f"qwen.h.{i}",
             )
         self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f")
