@@ -77,46 +77,8 @@ def train(self, train_dataset, val_dataset, data_collator):
         self._train_valid_test_dataset_provider = get_swift_datasets_provider(train_dataset, val_dataset)
         self._train_valid_test_dataset_provider.is_distributed = True
 
-        # Adjust EP parameters for Dense/MoE compatibility before Megatron initialization.
-        # Megatron's validate_args requires num_experts to be set when EP > 1.
-        # If student is Dense but user configured EP > 1 (expecting MoE teacher),
-        # we need to reset EP to 1 for student initialization.
-        self._adjust_ep_for_dense_student()
-
         super().train(train_dataset, val_dataset, data_collator)
 
-    def _adjust_ep_for_dense_student(self):
-        """Adjust EP parameters when student is Dense but EP > 1 is configured.
-
-        This handles the case where:
-        - Student is Dense (no num_experts)
-        - User configured EP > 1 (expecting to use EP for MoE teacher)
-
-        We reset EP to 1 for student initialization to avoid Megatron's validation error.
-        The original EP will be restored when loading MoE teacher model.
-        """
-        args = self.args
-        student_is_moe = getattr(args, 'num_experts', None) is not None
-
-        if not student_is_moe:
-            # Student is Dense, check if EP is configured
-            ep_size = getattr(args, 'expert_model_parallel_size', 1)
-            etp_size = getattr(args, 'expert_tensor_parallel_size', 1)
-
-            if ep_size > 1 or etp_size > 1:
-                # Save original EP settings for MoE teacher
-                self._original_ep_size = ep_size
-                self._original_etp_size = etp_size
-
-                # Reset EP to 1 for Dense student
-                args.expert_model_parallel_size = 1
-                args.expert_tensor_parallel_size = 1
-
-                # Also update extra_args if it exists
-                if hasattr(args, 'extra_args') and args.extra_args is not None:
-                    args.extra_args['expert_model_parallel_size'] = 1
-                    args.extra_args['expert_tensor_parallel_size'] = 1
-
     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
         """Setup model and optimizer, including teacher model.
 
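
For context, the helper removed above encodes one piece of reasoning: Megatron's argument validation rejects EP > 1 unless `num_experts` is set, so a Dense student configured with EP > 1 (intended for the MoE teacher) must be initialized with EP forced back to 1. Below is a minimal standalone sketch of that check, with a hypothetical `Args` namespace and helper name; it is illustrative, not the PR's code.

```python
# Sketch (not part of the PR): detect a Dense student and neutralize EP before
# Megatron initialization, remembering the configured sizes for later use.
from types import SimpleNamespace as Args

def reset_ep_for_dense_student(args):
    """Force EP/ETP to 1 for a Dense student; return (is_moe, saved sizes)."""
    student_is_moe = getattr(args, 'num_experts', None) is not None
    saved = (args.expert_model_parallel_size, args.expert_tensor_parallel_size)
    if not student_is_moe:
        # Dense student: the configured EP/ETP were meant for the MoE teacher only.
        args.expert_model_parallel_size = 1
        args.expert_tensor_parallel_size = 1
    return student_is_moe, saved

args = Args(num_experts=None, expert_model_parallel_size=4, expert_tensor_parallel_size=1)
is_moe, (ep, etp) = reset_ep_for_dense_student(args)
assert not is_moe and args.expert_model_parallel_size == 1 and ep == 4
```
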
@@ -223,10 +185,9 @@ def _load_teacher_model(self, teacher_model_path: str, model_type: str):
 
         # Restore original EP settings if they were saved during _adjust_ep_for_dense_student.
         # This allows MoE teacher to use EP > 1 even when student is Dense.
-        if hasattr(self, '_original_ep_size'):
-            megatron_args.expert_model_parallel_size = self._original_ep_size
-        if hasattr(self, '_original_etp_size'):
-            megatron_args.expert_tensor_parallel_size = self._original_etp_size
+        if self.student_is_moe:
+            megatron_args.expert_model_parallel_size = self.student_ep_size
+            megatron_args.expert_tensor_parallel_size = self.student_etp_size
         else:
             # Dense teacher cannot use expert parallelism.
             # Reset EP to 1 to avoid "num_moe_experts must be non None to use expert-parallel" error.
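
The branch above applies the complementary rule on the teacher side: reuse the configured expert-parallel sizes when expert parallelism applies, otherwise keep EP at 1 so Megatron does not raise "num_moe_experts must be non None to use expert-parallel". A hedged standalone sketch of that restore-or-reset step follows; the function name and the way the sizes are passed in are assumptions, since the surrounding file is not shown in this diff.

```python
# Sketch (outside the PR) of the teacher-side restore shown in the hunk above.
from types import SimpleNamespace

def restore_ep_for_teacher(megatron_args, student_is_moe, student_ep_size, student_etp_size):
    if student_is_moe:
        # Expert parallelism applies: give the teacher the configured EP layout.
        megatron_args.expert_model_parallel_size = student_ep_size
        megatron_args.expert_tensor_parallel_size = student_etp_size
    else:
        # Without experts, expert parallelism is invalid; keep EP at 1 so
        # Megatron's MoE validation does not fire.
        megatron_args.expert_model_parallel_size = 1
        megatron_args.expert_tensor_parallel_size = 1

m_args = SimpleNamespace(expert_model_parallel_size=1, expert_tensor_parallel_size=1)
restore_ep_for_teacher(m_args, student_is_moe=True, student_ep_size=4, student_etp_size=2)
assert m_args.expert_model_parallel_size == 4
```
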
@@ -795,8 +756,12 @@ def patched_validate_args(self, args, *_args, **kwargs):
         This is called before Megatron's validate_args, allowing us to reset EP to 1
         when student is Dense but EP > 1 was configured (for MoE teacher).
         """
-        if hasattr(self, '_original_ep_size') or hasattr(self, '_original_etp_size'):
+        student_is_moe = getattr(args, 'num_experts', None) is not None
+        if not student_is_moe:
             # Reset EP to 1 in Megatron args for Dense student
+            self._original_ep_size = args.expert_model_parallel_size
+            self._original_etp_size = args.expert_tensor_parallel_size
             args.expert_model_parallel_size = 1
             args.expert_tensor_parallel_size = 1
+        self.student_is_moe = student_is_moe
         return self._origin_validate_args(args, *_args, **kwargs)
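
`patched_validate_args` delegates to a stashed `self._origin_validate_args`, i.e. the usual wrap-and-delegate patch around Megatron's `validate_args`. The installation side is outside this diff; the sketch below illustrates the general pattern under that assumption, using a stand-in module object and a hypothetical class name.

```python
# Illustrative sketch of the wrap-and-delegate pattern: stash the original
# validate_args, substitute a bound wrapper, run the EP fix-up, then delegate.
from types import SimpleNamespace

class ValidateArgsPatch:
    def __init__(self, owner_module):
        # Keep a handle to the original so the wrapper can delegate to it.
        self._origin_validate_args = owner_module.validate_args
        owner_module.validate_args = self.patched_validate_args

    def patched_validate_args(self, args, *_args, **kwargs):
        # Pre-validation fix-up (as in the hunk above): force EP to 1 for a Dense student.
        if getattr(args, 'num_experts', None) is None:
            args.expert_model_parallel_size = 1
            args.expert_tensor_parallel_size = 1
        return self._origin_validate_args(args, *_args, **kwargs)

# Usage against a stand-in for the module that owns validate_args:
fake_megatron = SimpleNamespace(validate_args=lambda args, *a, **k: args)
ValidateArgsPatch(fake_megatron)
dense_args = SimpleNamespace(num_experts=None, expert_model_parallel_size=8,
                             expert_tensor_parallel_size=2)
fake_megatron.validate_args(dense_args)
assert dense_args.expert_model_parallel_size == 1
```
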