
Commit a54bd7d

Commit message: clean
1 parent 7305489

1 file changed (+8, -43 lines)

swift/megatron/trainers/gkd_trainer.py

Lines changed: 8 additions & 43 deletions
@@ -77,46 +77,8 @@ def train(self, train_dataset, val_dataset, data_collator):
         self._train_valid_test_dataset_provider = get_swift_datasets_provider(train_dataset, val_dataset)
         self._train_valid_test_dataset_provider.is_distributed = True
 
-        # Adjust EP parameters for Dense/MoE compatibility before Megatron initialization.
-        # Megatron's validate_args requires num_experts to be set when EP > 1.
-        # If student is Dense but user configured EP > 1 (expecting MoE teacher),
-        # we need to reset EP to 1 for student initialization.
-        self._adjust_ep_for_dense_student()
-
         super().train(train_dataset, val_dataset, data_collator)
 
-    def _adjust_ep_for_dense_student(self):
-        """Adjust EP parameters when student is Dense but EP > 1 is configured.
-
-        This handles the case where:
-        - Student is Dense (no num_experts)
-        - User configured EP > 1 (expecting to use EP for MoE teacher)
-
-        We reset EP to 1 for student initialization to avoid Megatron's validation error.
-        The original EP will be restored when loading MoE teacher model.
-        """
-        args = self.args
-        student_is_moe = getattr(args, 'num_experts', None) is not None
-
-        if not student_is_moe:
-            # Student is Dense, check if EP is configured
-            ep_size = getattr(args, 'expert_model_parallel_size', 1)
-            etp_size = getattr(args, 'expert_tensor_parallel_size', 1)
-
-            if ep_size > 1 or etp_size > 1:
-                # Save original EP settings for MoE teacher
-                self._original_ep_size = ep_size
-                self._original_etp_size = etp_size
-
-                # Reset EP to 1 for Dense student
-                args.expert_model_parallel_size = 1
-                args.expert_tensor_parallel_size = 1
-
-                # Also update extra_args if it exists
-                if hasattr(args, 'extra_args') and args.extra_args is not None:
-                    args.extra_args['expert_model_parallel_size'] = 1
-                    args.extra_args['expert_tensor_parallel_size'] = 1
-
     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
         """Setup model and optimizer, including teacher model.
 
@@ -223,10 +185,9 @@ def _load_teacher_model(self, teacher_model_path: str, model_type: str):
 
         # Restore original EP settings if they were saved during _adjust_ep_for_dense_student.
         # This allows MoE teacher to use EP > 1 even when student is Dense.
-        if hasattr(self, '_original_ep_size'):
-            megatron_args.expert_model_parallel_size = self._original_ep_size
-        if hasattr(self, '_original_etp_size'):
-            megatron_args.expert_tensor_parallel_size = self._original_etp_size
+        if self.student_is_moe:
+            megatron_args.expert_model_parallel_size = self.student_ep_size
+            megatron_args.expert_tensor_parallel_size = self.student_etp_size
         else:
             # Dense teacher cannot use expert parallelism.
             # Reset EP to 1 to avoid "num_moe_experts must be non None to use expert-parallel" error.
@@ -795,8 +756,12 @@ def patched_validate_args(self, args, *_args, **kwargs):
        This is called before Megatron's validate_args, allowing us to reset EP to 1
        when student is Dense but EP > 1 was configured (for MoE teacher).
        """
-        if hasattr(self, '_original_ep_size') or hasattr(self, '_original_etp_size'):
+        student_is_moe = getattr(args, 'num_experts', None) is not None
+        if not student_is_moe:
             # Reset EP to 1 in Megatron args for Dense student
+            self._original_ep_size = args.expert_model_parallel_size
+            self._original_etp_size = args.expert_tensor_parallel_size
             args.expert_model_parallel_size = 1
             args.expert_tensor_parallel_size = 1
+        self.student_is_moe = student_is_moe
         return self._origin_validate_args(args, *_args, **kwargs)
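
For readers following the change, the sketch below illustrates the pattern the new patched_validate_args lines rely on: decide once whether the student is MoE, stash the user-configured EP/ETP sizes, reset them to 1 so a Dense student passes Megatron's validation, and keep a student_is_moe flag for later use when the teacher is loaded. This is a minimal, self-contained approximation; the megatron_validate_args stand-in, the GKDTrainerSketch class, and the SimpleNamespace args are illustrative assumptions, not the actual swift or Megatron APIs.

from types import SimpleNamespace


def megatron_validate_args(args):
    # Toy stand-in for Megatron's validate_args: expert parallelism is only
    # allowed when the number of experts is set, which is the check that
    # forces the reset below for a Dense student.
    if args.expert_model_parallel_size > 1 and args.num_experts is None:
        raise ValueError('num_experts must be non None to use expert-parallel')
    return args


class GKDTrainerSketch:
    # Stand-in for the trainer; only the validate_args patching is modeled.
    _origin_validate_args = staticmethod(megatron_validate_args)

    def patched_validate_args(self, args, *_args, **kwargs):
        student_is_moe = getattr(args, 'num_experts', None) is not None
        if not student_is_moe:
            # Save the EP sizes the user configured (intended for the MoE
            # teacher) and reset them to 1 so the Dense student validates.
            self._original_ep_size = args.expert_model_parallel_size
            self._original_etp_size = args.expert_tensor_parallel_size
            args.expert_model_parallel_size = 1
            args.expert_tensor_parallel_size = 1
        self.student_is_moe = student_is_moe
        return self._origin_validate_args(args, *_args, **kwargs)


# Dense student (no num_experts) configured with EP=4 / ETP=2 for the teacher:
trainer = GKDTrainerSketch()
args = SimpleNamespace(num_experts=None,
                       expert_model_parallel_size=4,
                       expert_tensor_parallel_size=2)
trainer.patched_validate_args(args)
assert args.expert_model_parallel_size == 1 and not trainer.student_is_moe
assert trainer._original_ep_size == 4  # kept around for teacher loading

The teacher-loading side of the commit then consults student_is_moe (and the saved sizes) to decide whether the teacher may keep expert parallelism or must fall back to EP = 1 for a Dense teacher.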
