Skip to content

Commit a4ec58e

Browse files
committed
fix eval
1 parent 1068af1 commit a4ec58e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

medusa/eval/heads_accuracy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_accuracies(medusa, logit):
2727
def main(args):
2828
model = MedusaModel.from_pretrained(
2929
args.model_path,
30-
medusa_num_heads=args.medusa_num_heads,
30+
# medusa_num_heads=args.medusa_num_heads,
3131
torch_dtype=torch.float16,
3232
low_cpu_mem_usage=True,
3333
device_map="auto"
@@ -58,15 +58,15 @@ def main(args):
5858
model.current_length_data.zero_() # this is for rerun
5959
reset_medusa_mode(model)
6060
medusa_logits, outputs, logits = model(
61-
input_ids, past_key_values=past_key_values, output_orig=True
61+
input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True
6262
)
6363
_, medusa_topk = medusa_logits[...,-1,:].topk(20, dim=-1)
6464
input_id = logits[:, -1:].argmax(dim=-1)
6565
logits_ids.append(input_id.detach().cpu())
6666
medusa_topk_ids.append(medusa_topk.detach().cpu())
6767
for _ in range(steps):
6868
medusa_logits, outputs, logits = model(
69-
input_id, past_key_values=past_key_values, output_orig=True
69+
input_id, past_key_values=past_key_values, output_orig=True, medusa_forward=True
7070
)
7171
_, medusa_topk = medusa_logits[...,-1,:].topk(20, dim=-1)
7272
input_id = logits[:, -1:].argmax(dim=-1)

0 commit comments

Comments
 (0)