@@ -162,6 +162,10 @@ def main():
     model_config._attn_implementation = model_args.attn_impl
     model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
     model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
+    model_config.using_flex_token = model_args.using_flex_token
+    model_config.using_fake_gate = model_args.using_fake_gate
+    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
+    model_config.aux_loss_alpha = model_args.aux_loss_alpha
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")
 
@@ -172,11 +176,6 @@ def main():
 
         model_class = AutoModelForCausalLMPipe
 
-    model_config.using_flex_token = model_args.using_flex_token
-    model_config.using_fake_gate = model_args.using_fake_gate
-    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
-    model_config.aux_loss_alpha = model_args.aux_loss_alpha
-
     if model_args.continue_training and not training_args.autotuner_benchmark:
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
@@ -313,7 +312,8 @@ def neft_post_hook(module, input, output):
     if training_args.use_expert_parallel:
         callbacks += [MoeExpertsGradScaleCallback(training_args)]
 
-    print("callbacks:", callbacks, flush=True)
+    logger.info(f"callbacks: {callbacks}")
+
     trainer = SFTTrainer(
         model=model,
         args=training_args,
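
The first two hunks move the four MoE-related assignments so that model_config already carries them when it is logged and when the model is instantiated, and the last hunk routes the callbacks message through logger instead of print. Below is a minimal sketch of that config-before-construction step, assuming only that model_args exposes the four fields shown in the hunks; the helper name apply_moe_flags and the use of the standard logging module are illustrative, not taken from the file.

import logging

logger = logging.getLogger(__name__)

def apply_moe_flags(model_config, model_args):
    # Copy the MoE knobs onto the config before the model is built, so the
    # constructed model (e.g. AutoModelForCausalLMPipe.from_pretrained) sees them.
    model_config.using_flex_token = model_args.using_flex_token
    model_config.using_fake_gate = model_args.using_fake_gate
    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
    model_config.aux_loss_alpha = model_args.aux_loss_alpha
    # Report through the logger rather than print, matching the last hunk.
    logger.info(f"Final model config: {model_config}")
    return model_config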