Skip to content

Commit 0d4a166

Browse files
authored
Merge pull request #17 from ctlllll/main
Fix save medusa_head
2 parents b50cb41 + 70577b7 commit 0d4a166

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

medusa/train/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,10 @@ def train():
386386
# trainer.save_state()
387387
# safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
388388
# Save MedusaHead seperately
389-
if hasattr(model, "module"):
390-
lm_head = model.module.medusa_head
389+
if hasattr(medusa_lm_head, "module"):
390+
lm_head = medusa_lm_head.module.medusa_head
391391
else:
392-
lm_head = model.medusa_head
392+
lm_head = medusa_lm_head.medusa_head
393393

394394
# Save Medusa heads
395395
torch.save(

0 commit comments

Comments
 (0)