
Commit 1060ed4

[Cherry-pick] Cherry-pick hybrid-expert parallel related PRs from fleety (#11052)
* cherry-pick #10501
* cherry-pick #10526
* cherry-pick #10545
1 parent 2064039 · commit 1060ed4

File tree: 1 file changed, +27 −11 lines changed

paddlenlp/trainer/training_args.py

Lines changed: 27 additions & 11 deletions
@@ -1424,10 +1424,13 @@ def is_context_parallel_supported():
             else:
                 order = ["dp", "sharding", "pp", "mp"]
             if self.use_expert_parallel:
-                if is_context_parallel_supported():
-                    order = order[1:-1] + ["cp", "dp", "mp"]
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                    if is_context_parallel_supported():
+                        order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
+                    else:
+                        order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
                 else:
-                    order = order[1:-1] + ["dp", "mp"]
+                    order = ["sharding", "pp", "sep", "dp", "mp"]
 
             if is_context_parallel_supported():
                 hybrid_configs = {
@@ -1459,6 +1462,13 @@ def is_context_parallel_supported():
                     "order": order,
                 }
 
+            if self.expert_parallel_degree > 1:
+                assert (
+                    self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
+                ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_paralle:{self.use_expert_parallel}."
+                hybrid_configs["ep_degree"] = self.expert_parallel_degree
+                hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree
+
             try:
                 if self.split_norm_comm:
                     hybrid_configs["split_norm_comm"] = True
@@ -1587,9 +1597,6 @@ def is_context_parallel_supported():
             fleet.init(is_collective=True, strategy=strategy)
             logger.info(strategy)
 
-            if self.expert_parallel_degree > 1:
-                self.add_moe_comm_group()
-
         elif self.enable_auto_parallel:
             self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
             self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
@@ -2035,11 +2042,6 @@ def _post_init_parallel_degree(self):
                 logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
                 self.sharding = []
 
-            if sharding_parallel_degree > 1:
-                assert (
-                    sharding_parallel_degree % expert_parallel_degree == 0
-                ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."
-
             self.data_parallel_degree = world_size // (
                 sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
             )
@@ -2048,6 +2050,19 @@ def _post_init_parallel_degree(self):
                 assert (
                     self.expert_tensor_parallel_degree <= 1
                 ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+                moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
+            else:
+                moe_sharding_parallel_degree = 1
+            moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
+            if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
+                raise NotImplementedError(
+                    f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
+                )
+
+            if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
+                assert (
+                    sharding_parallel_degree % moe_sharding_parallel_degree == 0
+                ), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
 
             assert not (
                 self.data_parallel_degree > 1 and expert_parallel_degree > 1
@@ -2070,6 +2085,7 @@ def _post_init_parallel_degree(self):
             self.context_parallel_degree = context_parallel_degree
             self.expert_parallel_degree = expert_parallel_degree
             self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
+            self.moe_sharding_parallel_degree = moe_sharding_parallel_degree
 
             if not self.use_hybrid_parallel:
                 self.sharding = []
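For readers following the first two hunks: below is a minimal, standalone sketch of the new branch logic, not the actual TrainingArguments code. The helper names pick_hybrid_order and extend_hybrid_configs and the context_parallel_supported flag are illustrative only; the order lists and the ep_degree / moe_sharding_degree keys are taken from the diff above.

# Hypothetical helpers mirroring the order selection and hybrid_configs wiring in this commit.
from typing import Dict, List


def pick_hybrid_order(
    use_expert_parallel: bool,
    expert_parallel_degree: int,
    moe_sharding_parallel_degree: int,
    context_parallel_supported: bool,
) -> List[str]:
    # Default topology order from the surrounding (unchanged) code.
    order = ["dp", "sharding", "pp", "mp"]
    if use_expert_parallel:
        if moe_sharding_parallel_degree >= 1 and expert_parallel_degree > 1:
            # Hybrid expert parallel: moe_sharding/ep dims are placed explicitly in the order.
            if context_parallel_supported:
                order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
            else:
                order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
        else:
            order = ["sharding", "pp", "sep", "dp", "mp"]
    return order


def extend_hybrid_configs(
    hybrid_configs: Dict,
    use_expert_parallel: bool,
    expert_parallel_degree: int,
    moe_sharding_parallel_degree: int,
) -> Dict:
    # Mirrors the second hunk: only wire ep/moe_sharding degrees in when EP is active.
    if expert_parallel_degree > 1:
        assert use_expert_parallel and moe_sharding_parallel_degree >= 0
        hybrid_configs["ep_degree"] = expert_parallel_degree
        hybrid_configs["moe_sharding_degree"] = moe_sharding_parallel_degree
    return hybrid_configs


if __name__ == "__main__":
    order = pick_hybrid_order(True, 8, 2, context_parallel_supported=False)
    cfg = extend_hybrid_configs({"order": order}, True, 8, 2)
    # {'order': ['sharding', 'moe_sharding', 'pp', 'sep', 'dp', 'ep', 'mp'], 'ep_degree': 8, 'moe_sharding_degree': 2}
    print(cfg)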

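And a worked sketch of the new moe_sharding_parallel_degree derivation from the _post_init_parallel_degree hunks, again as standalone code with made-up example values (world_size=32, pp=2, ep=8, sharding=4), not the actual method:

# Illustrative recomputation of moe_sharding_parallel_degree as added in this commit.
# All values below are example inputs, not paddlenlp defaults.
world_size = 32
pipeline_parallel_degree = 2
expert_parallel_degree = 8
sharding_parallel_degree = 4
data_parallel_degree = 1  # the commit only supports EP together with sharding, not data parallel

if expert_parallel_degree > 1:
    # moe_sharding spans whatever ranks remain after pipeline and expert parallelism.
    moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
else:
    moe_sharding_parallel_degree = 1
moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)

if moe_sharding_parallel_degree > 1 and data_parallel_degree > 1:
    raise NotImplementedError("expert_data_parallel only works together with sharding, not data parallel")

if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
    # The removed hunk checked sharding % ep == 0; the new check is sharding % moe_sharding == 0.
    assert sharding_parallel_degree % moe_sharding_parallel_degree == 0

print(moe_sharding_parallel_degree)  # 32 // (2 * 8) = 2, and 4 % 2 == 0 passes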