Skip to content

Commit 4c7f5d6

Browse files
authored
Merge branch 'PaddlePaddle:develop' into dev_20250126_add_pipeline_for_moe
2 parents b104eaa + 54b8882 commit 4c7f5d6

File tree

3 files changed

+84
-8
lines changed

3 files changed

+84
-8
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,13 @@ def _save_checkpoint(self, model, metrics=None):
713713
for key, value in model.state_dict("opt").items()
714714
if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
715715
}
716+
model_state_dict = model.state_dict("param")
717+
if self.args.should_save_model_with_tensor_fusion:
718+
model_state_dict = self._convert_state_dict_for_saving_tensor_fusion_ckpt(model_state_dict)
719+
opt_state_dict = self._convert_state_dict_for_saving_tensor_fusion_ckpt(opt_state_dict)
720+
716721
state_dict = {
717-
MODEL_NAME: model.state_dict("param"),
722+
MODEL_NAME: model_state_dict,
718723
OPTIMIZER_NAME: opt_state_dict,
719724
}
720725
else:
@@ -854,6 +859,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
854859
for key, value in self.model_wrapped.state_dict("opt").items()
855860
if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
856861
}
862+
if self.args.should_load_model_with_tensor_fusion:
863+
model_state_dict = self._convert_state_dict_for_loading_tensor_fusion_ckpt(model_state_dict)
864+
optim_state_dict = self._convert_state_dict_for_loading_tensor_fusion_ckpt(optim_state_dict)
857865
else:
858866
model_state_dict = self.model_wrapped.state_dict()
859867
optim_state_dict = self.optimizer.state_dict()
@@ -888,7 +896,36 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
888896
self._load_ckpt_func(state_dict, ckpt_path)
889897

890898
if self.args.to_static:
899+
if self.args.should_load_model_with_tensor_fusion:
900+
model_state_dict = self._convert_state_dict_for_loading_model_with_tensor_fusion(model_state_dict)
901+
optim_state_dict = self._convert_state_dict_for_loading_model_with_tensor_fusion(optim_state_dict)
902+
891903
self.model_wrapped.set_state_dict(model_state_dict)
892904
self.model_wrapped.set_state_dict(optim_state_dict)
893905
# release memory
894906
del state_dict
907+
908+
def _convert_state_dict_for_loading_tensor_fusion_ckpt(self, state_dict):
909+
if self.args.load_model_with_sharding_tensor_fusion:
910+
logger.info("load sharding tensor fusion unbalanced model")
911+
state_dict = self.model_wrapped._convert_state_dict_with_rank_unique_name(state_dict)
912+
else:
913+
logger.info("load sharding tensor fusion balanced model")
914+
state_dict = self.model_wrapped._convert_state_dict_without_tensor_fusion_param(state_dict)
915+
return state_dict
916+
917+
def _convert_state_dict_for_loading_model_with_tensor_fusion(self, state_dict):
918+
if self.args.load_model_with_sharding_tensor_fusion:
919+
state_dict = self.model_wrapped._convert_state_dict_with_origin_name(state_dict)
920+
else:
921+
state_dict = self.model_wrapped._convert_state_dict_with_tensor_fusion_param(state_dict)
922+
return state_dict
923+
924+
def _convert_state_dict_for_saving_tensor_fusion_ckpt(self, state_dict):
925+
if self.args.save_model_with_sharding_tensor_fusion:
926+
logger.info("save sharding tensor fusion unbalanced model")
927+
state_dict = self.model_wrapped._convert_state_dict_with_rank_unique_name(state_dict)
928+
else:
929+
logger.info("save sharding tensor fusion balanced model")
930+
state_dict = self.model_wrapped._convert_state_dict_without_tensor_fusion_param(state_dict)
931+
return state_dict

paddlenlp/trainer/auto_training_args.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import json
1515
from dataclasses import dataclass, field
1616

17-
from .trainer_utils import split_parallel_config
17+
from .trainer_utils import ShardingOption, split_parallel_config
1818
from .training_args import TrainingArguments
1919
from .utils import add_start_docstrings
2020

@@ -52,6 +52,29 @@ class AutoTrainingArguments(TrainingArguments):
5252
metadata={"help": "Whether to use auto_parallel intermediate api"},
5353
)
5454
refined_ops_patterns: str = field(default=None, metadata={"help": "The pattern of refined recompute."})
55+
load_model_with_sharding_tensor_fusion: bool = field(
56+
default=False,
57+
metadata={
58+
"help": (
59+
"When using sharding stage1, enabling tensor fusion, and setting `load_model_with_sharding_tensor_fusion` to `True`, "
60+
"the model is loaded with unbalanced weights, meaning that the model weights are stored in an unbalanced format to avoid "
61+
"additional memory overhead. If set to `False`, the model will be loaded with balanced weights, which may increase memory "
62+
"consumption. This setting is only available in auto parallel to_static mode."
63+
)
64+
},
65+
)
66+
save_model_with_sharding_tensor_fusion: bool = field(
67+
default=False,
68+
metadata={
69+
"help": (
70+
"When using sharding stage1 and enabling tensor fusion, setting `save_model_with_sharding_tensor_fusion` to `True` "
71+
"saves the model with unbalanced weights, which helps avoid additional memory consumption. Setting it to `False` "
72+
"saves the model with balanced weights, which may increase memory usage but ensures uniform parameter distribution. "
73+
"This option allows flexibility in choosing the save format based on memory requirements. "
74+
"This setting is only available in auto parallel to_static mode."
75+
)
76+
},
77+
)
5578

5679
def __post_init__(self):
5780
super().__post_init__()
@@ -89,3 +112,13 @@ def __post_init__(self):
89112
recompute.refined_ops_patterns = (
90113
self.refined_ops_patterns if self.refined_ops_patterns is not None else []
91114
)
115+
116+
@property
117+
def should_load_model_with_tensor_fusion(self):
118+
return (
119+
self.enable_auto_parallel
120+
and self.to_static
121+
and ShardingOption.SHARD_OP in self.sharding
122+
and self.sharding_parallel_degree > 1
123+
and "enable_tensor_fusion" in self.sharding_parallel_config
124+
)

paddlenlp/trainer/training_args.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ class TrainingArguments:
619619
)
620620
},
621621
)
622+
622623
tensor_parallel_degree: int = field(
623624
default=-1,
624625
metadata={
@@ -740,7 +741,6 @@ class TrainingArguments:
740741
"enable_stage2_overlap, overlap stage2 NCCL communication with computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap and no other sync could be called during the training for broadcast overlap\n"
741742
"enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.\n"
742743
"enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.\n"
743-
"enable_tensor_fusion_blanced_save_load, convert unbalanced optimizer state to balanced state when using tensor fusion strategy, which may increase the memory occupation."
744744
)
745745
},
746746
)
@@ -1671,7 +1671,6 @@ def is_segment_parallel_supported():
16711671
"enable_tensor_fusion",
16721672
"enable_overlap",
16731673
"enable_release_grads",
1674-
"enable_tensor_fusion_blanced_save_load",
16751674
]:
16761675
if x in ["enable_stage1_overlap", "enable_stage2_overlap"]:
16771676
raise ValueError(
@@ -1686,7 +1685,7 @@ def is_segment_parallel_supported():
16861685
raise ValueError(
16871686
f"Found unknown sharding mode config {x}, "
16881687
f"accept config is enable_tensor_fusion, "
1689-
"enable_overlap, enable_release_grads, enable_tensor_fusion_blanced_save_load."
1688+
"enable_overlap, enable_release_grads."
16901689
)
16911690

16921691
if "enable_overlap" in sharding_parallel_config:
@@ -1696,9 +1695,6 @@ def is_segment_parallel_supported():
16961695
sharding.grad_bucket_size_numel = 210355872
16971696
sharding.enable_tensor_fusion = True
16981697

1699-
if "enable_tensor_fusion_blanced_save_load" in sharding_parallel_config:
1700-
sharding.save_unbalanced_param = False
1701-
17021698
if "enable_release_grads" in sharding_parallel_config:
17031699
sharding.release_gradients = True
17041700

@@ -2273,3 +2269,13 @@ def print_config(self, args=None, key=""):
22732269
logger.debug("{:30}: {}".format(a, v))
22742270

22752271
logger.debug("")
2272+
2273+
@property
2274+
def should_save_model_with_tensor_fusion(self):
2275+
return (
2276+
self.enable_auto_parallel
2277+
and self.to_static
2278+
and ShardingOption.SHARD_OP in self.sharding
2279+
and self.sharding_parallel_degree > 1
2280+
and "enable_tensor_fusion" in self.sharding_parallel_config
2281+
)

0 commit comments

Comments
 (0)