From 4719b4aea2024167d3b522bbdbe53f1a17a9e9c5 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Tue, 2 Sep 2025 18:10:56 +0800
Subject: [PATCH 1/3] cherry-pick #10501

---
 paddlenlp/trainer/training_args.py | 35 ++++++++++++++++++++----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 251d2390fc58..4d3394a0c3bb 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1424,10 +1424,13 @@ def is_context_parallel_supported():
             else:
                 order = ["dp", "sharding", "pp", "mp"]
             if self.use_expert_parallel:
-                if is_context_parallel_supported():
-                    order = order[1:-1] + ["cp", "dp", "mp"]
+                if self.moe_sharding_parallel_degree > 1 and self.expert_parallel_degree > 1:
+                    if is_context_parallel_supported():
+                        order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
+                    else:
+                        order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
                 else:
-                    order = order[1:-1] + ["dp", "mp"]
+                    order = ["sharding", "pp", "sep", "dp", "mp"]

             if is_context_parallel_supported():
                 hybrid_configs = {
@@ -1445,6 +1448,8 @@ def is_context_parallel_supported():
                     "mp_degree": self.tensor_parallel_degree,
                     "pp_degree": self.pipeline_parallel_degree,
                     "sharding_degree": self.sharding_parallel_degree,
+                    "moe_sharding_degree": self.moe_sharding_parallel_degree,
+                    "ep_degree": self.expert_parallel_degree,
                     "sep_degree": self.sep_parallel_degree
                     if self.sep_parallel_degree > 1
                     else self.context_parallel_degree,
@@ -1456,6 +1461,8 @@ def is_context_parallel_supported():
                     "mp_degree": self.tensor_parallel_degree,
                     "pp_degree": self.pipeline_parallel_degree,
                     "sharding_degree": self.sharding_parallel_degree,
+                    "moe_sharding_degree": self.moe_sharding_parallel_degree,
+                    "ep_degree": self.expert_parallel_degree,
                     "order": order,
                 }

@@ -1587,9 +1594,6 @@ def is_context_parallel_supported():
             fleet.init(is_collective=True, strategy=strategy)
             logger.info(strategy)

-            if self.expert_parallel_degree > 1:
-                self.add_moe_comm_group()
-
         elif self.enable_auto_parallel:
             self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
             self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
@@ -2035,11 +2039,6 @@ def _post_init_parallel_degree(self):
             logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
             self.sharding = []

-        if sharding_parallel_degree > 1:
-            assert (
-                sharding_parallel_degree % expert_parallel_degree == 0
-            ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."
-
         self.data_parallel_degree = world_size // (
             sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )
@@ -2048,6 +2047,19 @@ def _post_init_parallel_degree(self):
             assert (
                 self.expert_tensor_parallel_degree <= 1
             ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+            moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
+        else:
+            moe_sharding_parallel_degree = 1
+        moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
+        if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
+            raise NotImplementedError(
+                f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
+            )
+
+        if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
+            assert (
+                moe_sharding_parallel_degree % sharding_parallel_degree == 0
+            ), f"moe_sharding_parallel_degree should be divided by sharding_parallel_degree, current sharding_parallel_degree: {moe_sharding_parallel_degree}, expert_parallel_degree: {sharding_parallel_degree}."

         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1
@@ -2070,6 +2082,7 @@ def _post_init_parallel_degree(self):
         self.context_parallel_degree = context_parallel_degree
         self.expert_parallel_degree = expert_parallel_degree
         self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
+        self.moe_sharding_parallel_degree = moe_sharding_parallel_degree

         if not self.use_hybrid_parallel:
             self.sharding = []

From b98e84b4acf2f827765f7f2f234c2c079c2d0d19 Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Tue, 2 Sep 2025 18:12:50 +0800
Subject: [PATCH 2/3] cherry-pick #10526

---
 paddlenlp/trainer/training_args.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 4d3394a0c3bb..df4fe5f9846c 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -2058,8 +2058,8 @@ def _post_init_parallel_degree(self):

         if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
             assert (
-                moe_sharding_parallel_degree % sharding_parallel_degree == 0
-            ), f"moe_sharding_parallel_degree should be divided by sharding_parallel_degree, current sharding_parallel_degree: {moe_sharding_parallel_degree}, expert_parallel_degree: {sharding_parallel_degree}."
+                sharding_parallel_degree % moe_sharding_parallel_degree == 0
+            ), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."

         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1

From fafdf854a5621d7bb91e4f704bf7954e5ddd34dc Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Tue, 2 Sep 2025 18:15:04 +0800
Subject: [PATCH 3/3] cherry-pick #10545

---
 paddlenlp/trainer/training_args.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index df4fe5f9846c..bc0b60a3f938 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1424,7 +1424,7 @@ def is_context_parallel_supported():
             else:
                 order = ["dp", "sharding", "pp", "mp"]
             if self.use_expert_parallel:
-                if self.moe_sharding_parallel_degree > 1 and self.expert_parallel_degree > 1:
+                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
                     if is_context_parallel_supported():
                         order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
                     else:
@@ -1448,8 +1448,6 @@ def is_context_parallel_supported():
                     "mp_degree": self.tensor_parallel_degree,
                     "pp_degree": self.pipeline_parallel_degree,
                     "sharding_degree": self.sharding_parallel_degree,
-                    "moe_sharding_degree": self.moe_sharding_parallel_degree,
-                    "ep_degree": self.expert_parallel_degree,
                     "sep_degree": self.sep_parallel_degree
                     if self.sep_parallel_degree > 1
                     else self.context_parallel_degree,
@@ -1461,11 +1459,16 @@ def is_context_parallel_supported():
                     "mp_degree": self.tensor_parallel_degree,
                     "pp_degree": self.pipeline_parallel_degree,
                     "sharding_degree": self.sharding_parallel_degree,
-                    "moe_sharding_degree": self.moe_sharding_parallel_degree,
-                    "ep_degree": self.expert_parallel_degree,
                     "order": order,
                 }

+            if self.expert_parallel_degree > 1:
+                assert (
+                    self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
+                ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_paralle:{self.use_expert_parallel}."
+                hybrid_configs["ep_degree"] = self.expert_parallel_degree
+                hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree
+
             try:
                 if self.split_norm_comm:
                     hybrid_configs["split_norm_comm"] = True
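
Note (not part of the patches): the sketch below is a minimal standalone illustration of the degree arithmetic this series adds to _post_init_parallel_degree. moe_sharding_parallel_degree is derived as world_size // (pipeline_parallel_degree * expert_parallel_degree), expert data parallelism is only allowed to compose with sharding (not with plain data parallelism), and after PATCH 2/3 the divisibility check requires sharding_parallel_degree % moe_sharding_parallel_degree == 0. The function name and the example degrees are made up for illustration.

# Illustrative sketch only -- mirrors the checks introduced by the patches, not PaddleNLP's API.
def derive_moe_sharding_degree(
    world_size: int,
    pipeline_parallel_degree: int,
    expert_parallel_degree: int,
    sharding_parallel_degree: int,
    data_parallel_degree: int,
) -> int:
    # PATCH 1/3: moe sharding degree spans the ranks left over after pp and ep.
    if expert_parallel_degree > 1:
        moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
    else:
        moe_sharding_parallel_degree = 1
    moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)

    # Expert data parallel only composes with sharding, not with data parallel.
    if moe_sharding_parallel_degree > 1 and data_parallel_degree > 1:
        raise NotImplementedError("expert_data_parallel requires sharding_parallel, not data_parallel")

    # Divisibility check in its corrected PATCH 2/3 orientation.
    if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
        assert sharding_parallel_degree % moe_sharding_parallel_degree == 0
    return moe_sharding_parallel_degree


# Example with made-up degrees: 64 ranks, pp=4, ep=8, sharding=16, dp=1 -> moe sharding degree 2.
print(derive_moe_sharding_degree(64, 4, 8, 16, 1))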