We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents b50cb41 + 0d4a166 commit ae67991Copy full SHA for ae67991
medusa/train/train.py
@@ -386,10 +386,10 @@ def train():
386
# trainer.save_state()
387
# safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
388
# Save MedusaHead seperately
389
- if hasattr(model, "module"):
390
- lm_head = model.module.medusa_head
+ if hasattr(medusa_lm_head, "module"):
+ lm_head = medusa_lm_head.module.medusa_head
391
else:
392
- lm_head = model.medusa_head
+ lm_head = medusa_lm_head.medusa_head
393
394
# Save Medusa heads
395
torch.save(
0 commit comments