1+ import os
2+ import torch
3+ import json
4+ from contextlib import contextmanager
5+ import numpy as np
6+ from medusa .model .medusa_model import MedusaModel
7+ from medusa .model .kv_cache import *
8+ from medusa .model .utils import *
9+ from medusa .model .medusa_choices import *
10+ from copy import deepcopy
11+ import matplotlib .pyplot as plt
12+ import torch .nn .functional as F
13+ from fastchat .model .model_adapter import get_conversation_template
14+ from tqdm import tqdm
15+ import argparse
16+
def get_accuracies(medusa, logit):
    """Return per-head boolean hit masks against the base model's tokens.

    ``medusa`` is a (seq_len, num_heads, topk) tensor of candidate token ids
    (head k at step t predicts the token for step t + k + 1); ``logit`` holds
    the base model's greedy token ids, indexable as ``logit[t, 0]``.  The
    k-th returned mask marks, for every step where the comparison is defined,
    whether each of head k's top-k candidates equals the token the base model
    actually produced k + 1 steps later.
    """
    _, num_heads, _ = medusa.shape
    return [
        medusa[: -(head + 1), head].eq(logit[head + 1 :, 0])
        for head in range(num_heads)
    ]
24+
25+
26+
def main(args):
    """Measure how often each Medusa head's top-20 predictions contain the
    token the base model actually emits.

    For every sample in ``args.data_path`` (a JSON list whose items carry an
    ``"instruction"`` field), the base model is run greedily for
    ``args.steps`` decode steps; at each step the top-20 candidates of every
    Medusa head are recorded.  Per-head boolean hit masks are accumulated
    across all samples via :func:`get_accuracies` and saved to
    ``<save_dir>/<model_name>_heads_accuracy.pt``.
    """
    model = MedusaModel.from_pretrained(
        args.model_path,
        medusa_num_heads=args.medusa_num_heads,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    tokenizer = model.get_tokenizer()

    # NOTE(review): file handle is never closed explicitly; relies on GC.
    data = json.load(open(args.data_path))
    # Pre-allocate the KV cache once; it is reset in place for every sample.
    past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model, model.medusa_num_decoder_layers)
    model.past_key_values = past_key_values
    model.past_key_values_data = past_key_values_data
    model.current_length_data = current_length_data
    results = None  # list of per-head boolean masks, concatenated across samples

    for sample in tqdm((data)):
        # Build a fresh single-turn Vicuna-style prompt from the instruction.
        conv = get_conversation_template("vicuna")
        conv.messages = []
        conv.append_message(conv.roles[0], sample["instruction"])
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
        steps = args.steps
        logits_ids = []       # greedy next-token ids from the base LM head
        medusa_topk_ids = []  # top-20 candidate ids from each Medusa head

        with torch.inference_mode():
            input_ids = tokenizer([prompt]).input_ids
            input_ids = torch.as_tensor(input_ids).cuda()
            model.current_length_data.zero_()  # this is for rerun: reset KV-cache lengths
            reset_medusa_mode(model)
            # Prefill: run the whole prompt through the model once.
            medusa_logits, outputs, logits = model(
                input_ids, past_key_values=past_key_values, output_orig=True
            )
            # Top-20 candidates per head at the last position.
            # Assumes medusa_logits is (num_heads, batch, seq, vocab) — TODO confirm.
            _, medusa_topk = medusa_logits[..., -1, :].topk(20, dim=-1)
            # Greedy token from the base model's own head.
            input_id = logits[:, -1:].argmax(dim=-1)
            logits_ids.append(input_id.detach().cpu())
            medusa_topk_ids.append(medusa_topk.detach().cpu())
            # Decode `steps` further tokens one at a time, feeding back the
            # base model's greedy choice (the heads are only observed).
            for _ in range(steps):
                medusa_logits, outputs, logits = model(
                    input_id, past_key_values=past_key_values, output_orig=True
                )
                _, medusa_topk = medusa_logits[..., -1, :].topk(20, dim=-1)
                input_id = logits[:, -1:].argmax(dim=-1)
                logits_ids.append(input_id.detach().cpu())
                medusa_topk_ids.append(medusa_topk.detach().cpu())
            logits_ids = torch.stack(logits_ids, dim=0)
            # squeeze(2) drops the batch dimension, presumably leaving
            # (steps + 1, num_heads, 20) — verify against model output shapes.
            medusa_topk_ids = torch.stack(medusa_topk_ids, dim=0).squeeze(2)
            if results is None:
                results = get_accuracies(medusa_topk_ids, logits_ids)
            else:
                # cat sub results: append this sample's masks to each head's running mask
                cur_results = get_accuracies(medusa_topk_ids, logits_ids)
                for i in range(len(results)):
                    results[i] = torch.cat((results[i], cur_results[i]), dim=0)

    # NOTE(review): if `data` is empty, this saves None.
    save_path = os.path.join(args.save_dir, args.model_name + "_heads_accuracy.pt")
    torch.save(results, save_path)
87+
if __name__ == "__main__":
    # Command-line entry point: parse arguments, ensure the output directory
    # exists, then run the evaluation.
    parser = argparse.ArgumentParser(description="Medusa Model Evaluator")

    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the pre-trained Medusa model.")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name of the model.")
    parser.add_argument("--medusa_num_heads", type=int, default=5,
                        help="Number of medusa heads.")
    parser.add_argument("--data_path", type=str, required=True,
                        help="Path to the evaluation data in JSON format.")
    parser.add_argument("--save_dir", type=str, default="../../data",
                        help="Directory to save the results.")
    parser.add_argument("--steps", type=int, default=20,
                        help="Number of steps to run the model.")
    args = parser.parse_args()

    # Create the save directory if needed; exist_ok avoids the
    # check-then-create race of `if not os.path.exists(...)`.
    os.makedirs(args.save_dir, exist_ok=True)
    main(args)