diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 251d2390fc58..bc0b60a3f938 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1424,10 +1424,13 @@ def is_context_parallel_supported(): else: order = ["dp", "sharding", "pp", "mp"] if self.use_expert_parallel: - if is_context_parallel_supported(): - order = order[1:-1] + ["cp", "dp", "mp"] + if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: + if is_context_parallel_supported(): + order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"] + else: + order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"] else: - order = order[1:-1] + ["dp", "mp"] + order = ["sharding", "pp", "sep", "dp", "mp"] if is_context_parallel_supported(): hybrid_configs = { @@ -1459,6 +1462,13 @@ def is_context_parallel_supported(): "order": order, } + if self.expert_parallel_degree > 1: + assert ( + self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0 + ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_parallel:{self.use_expert_parallel}." 
+ hybrid_configs["ep_degree"] = self.expert_parallel_degree + hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree + try: if self.split_norm_comm: hybrid_configs["split_norm_comm"] = True @@ -1587,9 +1597,6 @@ def is_context_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) - if self.expert_parallel_degree > 1: - self.add_moe_comm_group() - elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) @@ -2035,11 +2042,6 @@ def _post_init_parallel_degree(self): logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!") self.sharding = [] - if sharding_parallel_degree > 1: - assert ( - sharding_parallel_degree % expert_parallel_degree == 0 - ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}." - self.data_parallel_degree = world_size // ( sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree ) @@ -2048,6 +2050,19 @@ def _post_init_parallel_degree(self): assert ( self.expert_tensor_parallel_degree <= 1 ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1" + moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree) + else: + moe_sharding_parallel_degree = 1 + moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1) + if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1: + raise NotImplementedError( + f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}." + ) + + if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1: + assert ( + sharding_parallel_degree % moe_sharding_parallel_degree == 0 + ), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}." assert not ( self.data_parallel_degree > 1 and expert_parallel_degree > 1 @@ -2070,6 +2085,7 @@ def _post_init_parallel_degree(self): self.context_parallel_degree = context_parallel_degree self.expert_parallel_degree = expert_parallel_degree self.expert_tensor_parallel_degree = expert_tensor_parallel_degree + self.moe_sharding_parallel_degree = moe_sharding_parallel_degree if not self.use_hybrid_parallel: self.sharding = []