@@ -27,7 +27,7 @@ def get_accuracies(medusa, logit):
2727def main (args ):
2828 model = MedusaModel .from_pretrained (
2929 args .model_path ,
30- medusa_num_heads = args .medusa_num_heads ,
30+ # medusa_num_heads=args.medusa_num_heads,
3131 torch_dtype = torch .float16 ,
3232 low_cpu_mem_usage = True ,
3333 device_map = "auto"
@@ -58,15 +58,15 @@ def main(args):
5858 model .current_length_data .zero_ () # this is for rerun
5959 reset_medusa_mode (model )
6060 medusa_logits , outputs , logits = model (
61- input_ids , past_key_values = past_key_values , output_orig = True
61+ input_ids , past_key_values = past_key_values , output_orig = True , medusa_forward = True
6262 )
6363 _ , medusa_topk = medusa_logits [...,- 1 ,:].topk (20 , dim = - 1 )
6464 input_id = logits [:, - 1 :].argmax (dim = - 1 )
6565 logits_ids .append (input_id .detach ().cpu ())
6666 medusa_topk_ids .append (medusa_topk .detach ().cpu ())
6767 for _ in range (steps ):
6868 medusa_logits , outputs , logits = model (
69- input_id , past_key_values = past_key_values , output_orig = True
69+ input_id , past_key_values = past_key_values , output_orig = True , medusa_forward = True
7070 )
7171 _ , medusa_topk = medusa_logits [...,- 1 ,:].topk (20 , dim = - 1 )
7272 input_id = logits [:, - 1 :].argmax (dim = - 1 )
0 commit comments