38 changes: 27 additions & 11 deletions paddlenlp/trainer/training_args.py
@@ -1424,10 +1424,13 @@ def is_context_parallel_supported():
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
if is_context_parallel_supported():
order = order[1:-1] + ["cp", "dp", "mp"]
if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
if is_context_parallel_supported():
order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
else:
order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
else:
order = order[1:-1] + ["dp", "mp"]
order = ["sharding", "pp", "sep", "dp", "mp"]

if is_context_parallel_supported():
hybrid_configs = {
@@ -1459,6 +1462,13 @@ def is_context_parallel_supported():
"order": order,
}

if self.expert_parallel_degree > 1:
assert (
self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_paralle:{self.use_expert_parallel}."
hybrid_configs["ep_degree"] = self.expert_parallel_degree
hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree

try:
if self.split_norm_comm:
hybrid_configs["split_norm_comm"] = True
@@ -1587,9 +1597,6 @@ def is_context_parallel_supported():
fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)

if self.expert_parallel_degree > 1:
self.add_moe_comm_group()

elif self.enable_auto_parallel:
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
@@ -2035,11 +2042,6 @@ def _post_init_parallel_degree(self):
logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
self.sharding = []

if sharding_parallel_degree > 1:
assert (
sharding_parallel_degree % expert_parallel_degree == 0
), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."

self.data_parallel_degree = world_size // (
sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)
@@ -2048,6 +2050,19 @@
assert (
self.expert_tensor_parallel_degree <= 1
), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
else:
moe_sharding_parallel_degree = 1
moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
raise NotImplementedError(
f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
)

if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
assert (
sharding_parallel_degree % moe_sharding_parallel_degree == 0
), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."

assert not (
self.data_parallel_degree > 1 and expert_parallel_degree > 1
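As a worked check of the degree arithmetic in this hunk (all numbers assumed for illustration, not taken from the PR):

# Assumed topology: 64 cards, mp=2, pp=2, sep=1, sharding=16, ep=8.
world_size = 64
tensor_parallel_degree = 2
sep_parallel_degree = 1
pipeline_parallel_degree = 2
sharding_parallel_degree = 16
expert_parallel_degree = 8

# data_parallel_degree = 64 // (16 * 2 * 1 * 2) = 1
data_parallel_degree = world_size // (
    sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)

# moe_sharding_parallel_degree = 64 // (2 * 8) = 4
moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)

# Both new constraints hold: 16 % 4 == 0 and data_parallel_degree == 1,
# so neither the divisibility assert nor the NotImplementedError fires.
assert sharding_parallel_degree % moe_sharding_parallel_degree == 0
assert not (data_parallel_degree > 1 and moe_sharding_parallel_degree > 1)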
Expand All @@ -2070,6 +2085,7 @@ def _post_init_parallel_degree(self):
self.context_parallel_degree = context_parallel_degree
self.expert_parallel_degree = expert_parallel_degree
self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
self.moe_sharding_parallel_degree = moe_sharding_parallel_degree

if not self.use_hybrid_parallel:
self.sharding = []