
Commit 9f42fc3

Cherry-pick hybrid expert parallel sharding_metas (#2447)
1 parent 877b5a7 commit 9f42fc3

5 files changed: +227 additions, -56 deletions

paddleformers/trainer/trainer.py

Lines changed: 7 additions & 1 deletion
@@ -377,7 +377,13 @@ def __init__(
         self.optimizer_grouped_parameters = None
         self.sharding_io = None
         if self.args.should_save_sharding_stage1_model or self.args.should_load_sharding_stage1_model:
-            self.sharding_io = ShardingIO(self.args, self.model, self.optimizer)
+            self.sharding_io = ShardingIO(
+                self.args,
+                self.model,
+                self.optimizer,
+                remap_parameter_name=self.args.load_sharded_model_remap_parameter_name,
+            )
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args)

paddleformers/trainer/training_args.py

Lines changed: 5 additions & 0 deletions
@@ -627,6 +627,11 @@ class TrainingArguments:
         },
     )

+    load_sharded_model_remap_parameter_name: bool = field(
+        default=False,
+        metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
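
Once this flag is in TrainingArguments, the remapping can be switched on from the training configuration. A minimal sketch, assuming a paddleformers build that contains this commit (the output directory is a placeholder):

# Minimal sketch: enable parameter-name remapping for sharded-model loading.
# Assumes paddleformers with this commit is installed; the path is a placeholder.
from paddleformers.trainer.training_args import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",                    # placeholder path
    load_sharded_model_remap_parameter_name=True,  # new flag; defaults to False
)
# Per the help text above, the remap only takes effect when load_sharded_model = true,
# i.e. when a stage-1 sharded checkpoint is loaded through ShardingIO (see trainer.py above).
print(args.load_sharded_model_remap_parameter_name)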

paddleformers/trainer/utils/reshard/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     SHARDING_STRATEGY_V2,
     NodeModelState,
     all_gather_state_dict,
+    convert_opt_name_to_tname,
     get_moe_sharding_group,
     get_param_sharding_group,
     get_sharding_strategy,
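
With this one-line re-export, downstream code can pull the helper from the reshard package instead of reaching into common.py. A quick check, assuming a paddleformers install containing this commit:

# The helper now sits next to the existing reshard utilities at package level.
from paddleformers.trainer.utils.reshard import (
    convert_opt_name_to_tname,
    get_sharding_strategy,  # pre-existing export, shown for context
)

# Expected (assuming the re-export comes from common.py, as the next diff suggests):
# paddleformers.trainer.utils.reshard.common
print(convert_opt_name_to_tname.__module__)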

paddleformers/trainer/utils/reshard/common.py

Lines changed: 46 additions & 38 deletions
@@ -22,6 +22,8 @@
 )
 from paddle.distributed.fleet.utils.log_util import logger

+from paddleformers.utils.tools import get_env_device
+
 try:
     from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
         DygraphShardingOptimizerV2,
@@ -61,6 +63,49 @@ def get_sharding_strategy(optimizer):
     return SHARDING_STRATEGY_V1


+def convert_opt_name_to_tname(tensor_names, opt_names):
+    tensor_names = set(tensor_names)
+    all_names = []
+    all_names.extend(list(tensor_names))
+    all_names.extend(opt_names)
+    all_names.sort()
+    pre_t_name = ""
+    suffix = [
+        "_fp32_master_0_beta1_pow_acc_0",
+        "_fp32_master_0_beta2_pow_acc_0",
+        "_fp32_master_0_moment1_0",
+        "_fp32_master_0_moment2_0",
+        "_beta1_pow_acc_0",
+        "_beta2_pow_acc_0",
+        "_moment1_0",
+        "_moment2_0",
+    ]
+    opt_to_t = {}
+    for n in all_names:
+        if n in tensor_names:
+            # we get a param
+            pre_t_name = n
+        else:
+            assert pre_t_name
+            opt_to_t[n] = pre_t_name
+
+    for t in opt_names:
+        _find = False
+        for s in suffix:
+            if get_env_device() == "xpu" and t.endswith(s + ".SCALE_VALUE"):
+                # NOTE: for xpu adamw, all optimizer state will have an extra attribute end with SCALE_VALUE.
+                # This extra attribute won't be used, just skip it.
+                _find = True
+                break
+            if t.endswith(s):
+                logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
+                opt_to_t[t] = t[: -len(s)]
+                _find = True
+                break
+        assert _find
+    return opt_to_t
+
+
 class NodeModelState:
     def __init__(self, group):
         self._model_weights = OrderedDict()
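
The helper pairs every optimizer-state name with the parameter that owns it: it sorts parameter and state names together, then strips the known Adam/AdamW state suffixes. A worked sketch of the mapping it produces; the parameter names are made up for illustration, and a paddleformers install with this commit is assumed:

# Illustrative input: two parameters plus a few of their optimizer-state tensors.
from paddleformers.trainer.utils.reshard.common import convert_opt_name_to_tname

tensor_names = ["linear_0.w_0", "linear_0.b_0"]
opt_names = [
    "linear_0.w_0_fp32_master_0_moment1_0",
    "linear_0.w_0_fp32_master_0_beta1_pow_acc_0",
    "linear_0.b_0_moment2_0",
]

mapping = convert_opt_name_to_tname(tensor_names, opt_names)
# Each optimizer-state name maps back to its owning parameter:
# {
#     "linear_0.w_0_fp32_master_0_moment1_0": "linear_0.w_0",
#     "linear_0.w_0_fp32_master_0_beta1_pow_acc_0": "linear_0.w_0",
#     "linear_0.b_0_moment2_0": "linear_0.b_0",
# }
print(mapping)

On XPU, the same state names can carry an extra ".SCALE_VALUE" suffix from the XPU AdamW kernel; the new get_env_device() branch above skips those entries instead of tripping the suffix assertion.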
@@ -259,43 +304,6 @@ def pack_keys(self, structure_name_mapping=None):
         change the key of master weights dict from param_name to (structure_name, param_name)
         """
         # pack key for pp convert
-        def _opt_name_to_tname(tensor_names, opt_names):
-            tensor_names = set(tensor_names)
-            all_names = []
-            all_names.extend(list(tensor_names))
-            all_names.extend(opt_names)
-            all_names.sort()
-            pre_t_name = ""
-            suffix = [
-                "_fp32_master_0_beta1_pow_acc_0",
-                "_fp32_master_0_beta2_pow_acc_0",
-                "_fp32_master_0_moment1_0",
-                "_fp32_master_0_moment2_0",
-                "_beta1_pow_acc_0",
-                "_beta2_pow_acc_0",
-                "_moment1_0",
-                "_moment2_0",
-            ]
-            opt_to_t = {}
-            for n in all_names:
-                if n in tensor_names:
-                    # we get a param
-                    pre_t_name = n
-                else:
-                    assert pre_t_name
-                    opt_to_t[n] = pre_t_name
-
-            for t in opt_names:
-                _find = False
-                for s in suffix:
-                    if t.endswith(s):
-                        logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
-                        opt_to_t[t] = t[: -len(s)]
-                        _find = True
-                        break
-                assert _find
-            return opt_to_t
-
         if structure_name_mapping is not None:
             tname_to_structure_name = {v: k for (k, v) in structure_name_mapping.items()}
         else:
@@ -304,7 +312,7 @@ def _opt_name_to_tname(tensor_names, opt_names):

         tensor_names = list(tname_to_structure_name.keys())
         opt_names = list(self._opt_state.keys())
-        opt_name_to_tname = _opt_name_to_tname(tensor_names, opt_names)
+        opt_name_to_tname = convert_opt_name_to_tname(tensor_names, opt_names)

         # model state
         model_weights_tmp = OrderedDict()
