diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index 9e86b3f9825a..b16389d4d6d3 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -169,6 +169,8 @@ def __init__(
         self.moe_num_experts = moe_num_experts
         self.capacity = capacity
 
+        self.is_tp_moe = False
+        self.is_dp_moe = False
 
         try:
             dist.fleet.get_hybrid_communicate_group()
@@ -190,6 +192,7 @@ def __init__(
                 self.moe_num_experts, self.expert_parallel_degree
             )
             self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True
+            self.is_dp_moe = True
         elif (
             is_fleet_init
             and dist.fleet.get_hybrid_communicate_group().get_model_parallel_world_size() > 1
@@ -210,6 +213,7 @@ def __init__(
             )  # e.g. 2-way tp on a single machine, so 32 = 128/4
             self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True  # False
+            self.is_tp_moe = True
 
         else:
             self.moe_group = None
             self.moe_rank = 0
@@ -250,9 +254,11 @@ def _post_init(self):
         for k in self.experts:
             if k is not None:
                 for p in k.parameters():
-                    p.expert = not self.is_dummy_moe
-                    p.no_sync = not self.is_dummy_moe
-                    # logger.info(f"expert param={p.name}, no-sync={p.no_sync}")
+                    p.expert = not (self.is_tp_moe or self.is_dummy_moe)  # type: ignore
+                    p.no_sync = not (self.is_tp_moe or self.is_dummy_moe)
+                    logger.info(f"expert param={p.name}, no-sync={p.no_sync}")
+                    if self.is_mp_moe or self.is_dp_moe:
+                        p.is_distributed = True
 
     def forward(
         self,
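
For context, a minimal dependency-free sketch of the attribute tagging this patch introduces in _post_init. Everything below is illustrative rather than the PaddleNLP implementation: Param and FlagSketch are hypothetical stand-ins for a paddle Parameter and the MoE layer, only the flag names (is_tp_moe, is_dp_moe, is_mp_moe, is_dummy_moe) and the boolean logic mirror the diff, and is_mp_moe is assumed to be defined elsewhere in the real class, since the diff references it without adding it.

# Minimal sketch (not the PaddleNLP implementation) of how the flags added in
# this diff drive the parameter attributes set in _post_init().

class Param:
    """Stand-in for a paddle Parameter; only carries the tagged attributes."""

    def __init__(self, name: str):
        self.name = name
        self.expert = False
        self.no_sync = False
        self.is_distributed = False


class FlagSketch:
    def __init__(self, is_tp_moe=False, is_dp_moe=False, is_mp_moe=False, is_dummy_moe=False):
        self.is_tp_moe = is_tp_moe
        self.is_dp_moe = is_dp_moe
        self.is_mp_moe = is_mp_moe  # assumed to be set elsewhere in the real class
        self.is_dummy_moe = is_dummy_moe

    def tag(self, params):
        for p in params:
            # Mirrors the diff: under TP-backed MoE (or the single-rank "dummy"
            # case) expert weights are not excluded from the regular grad sync.
            p.expert = not (self.is_tp_moe or self.is_dummy_moe)
            p.no_sync = not (self.is_tp_moe or self.is_dummy_moe)
            # Mirrors the diff: params belonging to an MP- or DP-backed MoE
            # group are additionally marked as distributed.
            if self.is_mp_moe or self.is_dp_moe:
                p.is_distributed = True
        return params


if __name__ == "__main__":
    # Expert parallelism over the data-parallel group: params are tagged as
    # per-rank expert params (expert/no_sync True) and as distributed.
    (p,) = FlagSketch(is_dp_moe=True).tag([Param("expert.w1")])
    print(p.expert, p.no_sync, p.is_distributed)  # True True True

    # Tensor-parallel MoE: params stay in the normal sync path (expert/no_sync
    # False); is_distributed is only set if is_mp_moe is also set elsewhere.
    (q,) = FlagSketch(is_tp_moe=True).tag([Param("expert.w1")])
    print(q.expert, q.no_sync, q.is_distributed)  # False False False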