
Commit 2064039

Cherry-pick PRs from incubate/paddlenlp_fleety (#11049)
* add offload opt util (#10607)
* fix offload optimizer (#10608)
* fix sharding reshard bug (#10613)
* reset ema (#10790)
* Move fused_quanted_ops and token_dispatcher_utils to FleetY (#10803)
* move setup_fp8.py (#10820)
* Add tokens_zip_unique_add_subbatch and merge_subbatch_cast ops (#10822)
* add tokens_zip_unique_add_subbatch and merge_subbatch_cast ops
* enhance ut
* enhance ut again
* remove duplicate codes
* enhance ut again
* Sharding reshard supports mismatch parameter name (#10479)
* Sharding Reshard Supports N -> 1 (#10532)
* cherry-pick #10531
* support tpdp-ep sharding reshard (#10568)
* fix typo
* revert sharding_first
1 parent 87515b4 commit 2064039

29 files changed: +6091 −163 lines

.pre-commit-config.yaml

Lines changed: 9 additions & 0 deletions
@@ -62,3 +62,12 @@ repos:
         language: python
         files: \.(md|markdown|rst)$
         pass_filenames: true
+
+-   repo: local
+    hooks:
+    -   id: clang-format
+        name: clang-format
+        description: Format files with ClangFormat.
+        entry: bash ./tools/codestyle/clang_format.sh -i
+        language: system
+        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$

paddlenlp/trainer/trainer.py

Lines changed: 16 additions & 3 deletions
@@ -372,7 +372,13 @@ def __init__(
         self.optimizer_grouped_parameters = None
         self.sharding_io = None
         if self.args.should_save_sharding_stage1_model or self.args.should_load_sharding_stage1_model:
-            self.sharding_io = ShardingIO(self.args, self.model, self.optimizer)
+            self.sharding_io = ShardingIO(
+                self.args,
+                self.model,
+                self.optimizer,
+                remap_parameter_name=self.args.load_sharded_model_remap_parameter_name,
+            )
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args)

@@ -805,9 +811,16 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
+            if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
+                success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
+            else:
+                success, err_msg = True, None
             if os.path.exists(path):
-                logger.info(f"ZCC EMA load from {path}")
-                self.zcc_manager.set_ema_state_dict(path)
+                if success:
+                    logger.info(f"ZCC EMA load from {path}")
+                    self.zcc_manager.set_ema_state_dict(path)
+                else:
+                    logger.info(f"ZCC EMA does not load {path} because {err_msg}")
             else:
                 logger.info(f"ZCC EMA state dict not found, in: {path}")
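Taken together with the training_args.py change below, the new ShardingIO wiring is driven entirely from arguments. A minimal sketch of opting in, assuming the existing load_sharded_model flag alongside the new field added in this commit (values are illustrative; unrelated settings omitted):

from paddlenlp.trainer import TrainingArguments

# Minimal sketch; values are illustrative, not a recommended configuration.
args = TrainingArguments(
    output_dir="./checkpoints",
    load_sharded_model=True,  # existing flag consulted by should_load_sharding_stage1_model
    load_sharded_model_remap_parameter_name=True,  # new field added in this commit
)

# With this commit, should_load_sharding_stage1_model reduces to load_sharded_model,
# and Trainer forwards the new flag as ShardingIO(..., remap_parameter_name=...).
print(args.should_load_sharding_stage1_model)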

paddlenlp/trainer/training_args.py

Lines changed: 30 additions & 9 deletions
@@ -633,6 +633,11 @@ class TrainingArguments:
         },
     )

+    load_sharded_model_remap_parameter_name: bool = field(
+        default=False,
+        metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
@@ -2039,6 +2044,11 @@ def _post_init_parallel_degree(self):
             sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )

+        if expert_parallel_degree > 1:
+            assert (
+                self.expert_tensor_parallel_degree <= 1
+            ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+
         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1
         ), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}."
@@ -2227,6 +2237,17 @@ def pipeline_parallel_rank(self):
         else:
             return 0

+    @property
+    def expert_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_expert_parallel_rank"):
+                return max(hcg.get_expert_parallel_rank(), 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def context_parallel_rank(self):
         if self.use_hybrid_parallel:
@@ -2252,7 +2273,7 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
@@ -2268,7 +2289,7 @@ def weight_name_suffix(self):
                 name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)

@@ -2277,7 +2298,9 @@ def weight_name_suffix(self):
                 return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None

-    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None):
+        if sharding_parallel_degree is None:
+            sharding_parallel_degree = self.sharding_parallel_degree
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
@@ -2287,12 +2310,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
                 pp_id = self.pipeline_parallel_rank
                 assert isinstance(pp_id, int)
                 name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
-            if self.sharding_parallel_degree > 1:
+            if sharding_parallel_degree > 1:
                 if shard_id is None:
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
-                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+                name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 if moe_id is None:
                     moe_id = self.data_parallel_rank
                 assert isinstance(moe_id, int)
@@ -2418,9 +2441,7 @@ def should_save_sharding_stage1_model(self):
     def should_load_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
-        return (
-            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model
-        )
+        return self.load_sharded_model

     @property
     def should_load_dataset(self):
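The extra sharding_parallel_degree parameter on sharded_name_suffix is what lets reshard code build suffixes for a source sharding layout that differs from the running one (the N -> 1 case from #10532). A rough standalone illustration of the idea; format_name below is a hypothetical stand-in for the private _format_name helper, not its real output format:

def format_name(prefix, rank, degree):
    # Hypothetical stand-in for TrainingArguments._format_name.
    return f"{prefix}{rank:02d}"


def sharded_name_suffix(shard_id, pp_id, pp_degree, sharding_parallel_degree):
    # Mirrors the structure of the patched method: the shard part is built from the
    # caller-supplied degree instead of the degree of the current run.
    name = []
    if pp_degree > 1:
        name.append(format_name("pp", pp_id, pp_degree))
    if sharding_parallel_degree > 1:
        name.append(format_name("shard", shard_id, sharding_parallel_degree))
    return "_".join(name)


# A job running without sharding can still enumerate the suffixes of a checkpoint
# written by an 8-way sharded job, which is what N -> 1 resharding needs:
print([sharded_name_suffix(s, pp_id=0, pp_degree=2, sharding_parallel_degree=8) for s in range(8)])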
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import _C_ops
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
+    HybridParallelOptimizer,
+)
+from paddle.optimizer import Optimizer
+
+from .sharding_io import to_device
+
+
+def offload(tensor):
+    if paddle.is_compiled_with_cuda():
+        place = paddle.CUDAPinnedPlace()
+    else:
+        place = paddle.CPUPlace()
+
+    new_tensor = to_device(tensor, place)
+    assert new_tensor is tensor, "to_device must be inplace operation"
+
+
+def reload(tensor):
+    new_tensor = to_device(tensor)
+    assert new_tensor is tensor, "to_device must be inplace operation"
+
+
+def hack_offload_optimizer():
+    # Step 1: mock _add_accumulator
+    origin_add_accumulator = getattr(Optimizer, "_add_accumulator")
+
+    def new_add_accumulator(self, *args, **kwargs):
+        x = origin_add_accumulator(self, *args, **kwargs)
+        offload(x)
+        return x
+
+    setattr(Optimizer, "_add_accumulator", new_add_accumulator)
+
+    # Step 2: mock _C_ops.adam_ and _C_ops.adamw_
+    for name in ["adam_", "adamw_"]:
+        origin_op = getattr(_C_ops, name)
+
+        def new_opt_op(*args):
+            for arg in args:
+                if isinstance(arg, paddle.Tensor):
+                    reload(arg)
+
+            ret = origin_op(*args)
+
+            for i, arg in enumerate(args):
+                if i >= 2 and isinstance(arg, paddle.Tensor):  # do not offload parameter and gradient
+                    offload(arg)
+            return ret
+
+        setattr(_C_ops, name, new_opt_op)
+
+    # Step 3: mock _insert_sync
+    opt_type = HybridParallelOptimizer
+    origin_insert_sync = getattr(opt_type, "_insert_sync")
+
+    def new_insert_sync(self, sync_var, *args, **kwargs):
+        origin_place = sync_var.place
+        reload(sync_var)
+        ret = origin_insert_sync(self, sync_var, *args, **kwargs)
+        new_sync_var = to_device(sync_var, origin_place)
+        assert new_sync_var is sync_var, "to_device must be inplace operation"
+        return ret
+
+    setattr(opt_type, "_insert_sync", new_insert_sync)
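A short usage sketch for the new helper. The import path below is an assumption (the diff header above omits the new file's name); the intended pattern, per the code, is to call hack_offload_optimizer() once before any optimizer is constructed, so accumulators created by _add_accumulator stay in pinned/CPU memory and are only moved back to the device inside the patched adam_/adamw_ ops and _insert_sync.

import paddle

# Assumed module path; the diff above does not show where the new file lives.
from paddlenlp.trainer.utils.offload_optimizer import hack_offload_optimizer

# Patch Optimizer._add_accumulator, _C_ops.adam_/adamw_ and
# HybridParallelOptimizer._insert_sync before building the optimizer.
hack_offload_optimizer()

model = paddle.nn.Linear(1024, 1024)
opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

loss = model(paddle.randn([8, 1024])).mean()
loss.backward()
opt.step()  # moment tensors are reloaded to device for the update, then offloaded again
opt.clear_grad()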

paddlenlp/trainer/utils/reshard/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,14 @@
     SHARDING_STRATEGY_V2,
     NodeModelState,
     all_gather_state_dict,
+    convert_opt_name_to_tname,
+    get_moe_sharding_group,
+    get_param_sharding_group,
     get_sharding_strategy,
     is_sharding_opt,
+    merge_model_state,
+    merge_opt_state,
+    split_model_state,
+    split_opt_state,
+    split_structure_name_mapping,
 )
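The widened import list simply re-exports the reshard helpers at the package level, so downstream code (for example the sharding reshard paths touched elsewhere in this commit) can import them directly. A minimal sketch, assuming the package layout shown in the file path above:

# Names newly re-exported from paddlenlp.trainer.utils.reshard by this commit.
from paddlenlp.trainer.utils.reshard import (
    convert_opt_name_to_tname,
    get_moe_sharding_group,
    get_param_sharding_group,
    merge_model_state,
    merge_opt_state,
    split_model_state,
    split_opt_state,
    split_structure_name_mapping,
)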
