 from medusa.model.kv_cache import initialize_past_key_values
 from medusa.model.medusa_choices import *
 
-def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps=512):
+def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, top_p=0.8, sampling='typical', fast=True, max_steps=512):
     assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
     # Avoid modifying the input_ids in-place
     input_ids = input_ids.clone()
@@ -71,6 +71,7 @@ def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, pos
             logits,
             medusa_buffers["tree_indices"],
             medusa_buffers["retrieve_indices"],
+            temperature, posterior_threshold, posterior_alpha, top_p, sampling, fast
         )
         medusa_logits, logits, outputs = tree_decoding(
             model,
@@ -81,7 +82,7 @@ def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, pos
             medusa_buffers["retrieve_indices"],
         )
         best_candidate, accept_length = evaluate_posterior(
-            logits, candidates, temperature, posterior_threshold, posterior_alpha
+            logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p, sampling, fast
         )
         input_ids, logits, medusa_logits, new_token = update_inference_inputs(
             input_ids,
@@ -117,6 +118,9 @@ def run_eval(
     temperature,
     posterior_threshold,
     posterior_alpha,
+    top_p,
+    sampling,
+    fast,
     medusa_choices,
 ):
     questions = load_questions(question_file, question_begin, question_end)
@@ -153,6 +157,9 @@ def run_eval(
                 temperature,
                 posterior_threshold,
                 posterior_alpha,
+                sampling,
+                top_p,
+                fast,
                 medusa_choices,
             )
         )
@@ -174,15 +181,22 @@ def get_model_answers(
     temperature,
     posterior_threshold,
     posterior_alpha,
+    sampling,
+    top_p,
+    fast,
     medusa_choices,
 ):
 
     # Medusa model setup
-    num_heads = 4
+
+    num_heads = -1
+    for choice in medusa_choices:
+        if len(choice) > num_heads:
+            num_heads = len(choice)
 
     model = MedusaModel.from_pretrained(
         model_path,
-        medusa_num_heads = num_heads,
+        # medusa_num_heads = num_heads,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
         device_map="auto"
@@ -200,7 +214,7 @@ def get_model_answers(
 
     # warmup
     for _ in range(3):
-        torch.manual_seed(0)
+        # torch.manual_seed(0)
         conv = get_conversation_template(model_id)
         turns = []
         idxs = []
@@ -227,9 +241,12 @@ def get_model_answers(
                     model,
                     tokenizer,
                     medusa_choices,
-                    temperature,
+                    0.7,
                     posterior_threshold,
                     posterior_alpha,
+                    top_p=top_p,
+                    sampling=sampling,
+                    fast=fast,
                 )
                 torch.cuda.synchronize()
                 total_time = time.time() - start_time
@@ -261,6 +278,7 @@ def get_model_answers(
                 if conv.name == "xgen" and output.startswith("Assistant:"):
                     output = output.replace("Assistant:", "", 1).strip()
             except RuntimeError as e:
+                print(e)
                 print("ERROR question ID: ", question["question_id"])
                 output = "ERROR"
 
@@ -280,7 +298,7 @@ def get_model_answers(
 
     choices = []
     for i in range(num_choices):
-        torch.manual_seed(i)
+        # torch.manual_seed(i)
         conv = get_conversation_template(model_id)
         turns = []
         idxs = []
@@ -310,6 +328,9 @@ def get_model_answers(
                     temperature,
                     posterior_threshold,
                     posterior_alpha,
+                    top_p=top_p,
+                    sampling=sampling,
+                    fast=fast,
                 )
                 torch.cuda.synchronize()
                 total_time = time.time() - start_time
@@ -456,19 +477,39 @@ def reorg_answer_file(answer_file):
         help="The posterior alpha for medusa sampling.",
     )
 
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.8,
+        help="The top-p for medusa sampling.",
+    )
+
+    parser.add_argument(
+        "--sampling",
+        type=str,
+        default="typical",
+        help="The sampling method for medusa sampling.",
+    )
+
+    parser.add_argument(
+        "--fast",
+        action="store_true",
+        help="Whether to use fast decoding.",
+    )
+
     parser.add_argument(
         "--medusa-choices",
         type=str,
         default="mc_sim_7b_63",
         help="The medusa choices for medusa sampling.",
     )
 
-
+
 
 
470511
-    args.model_id = args.model_id + "-temperature-" + str(args.temperature) + "-posterior_threshold-" + str(args.posterior_threshold) + "-posterior_alpha-" + str(args.posterior_alpha)
+    args.model_id = args.model_id + "-temperature-" + str(args.temperature) + "-posterior_threshold-" + str(args.posterior_threshold) + "-posterior_alpha-" + str(args.posterior_alpha) + "-top_p-" + str(args.top_p) + "-sampling-" + args.sampling + "-fast-" + str(args.fast)
     args.medusa_choices = eval(args.medusa_choices)
     if args.num_gpus_total // args.num_gpus_per_model > 1:
         import ray
@@ -499,6 +540,9 @@ def reorg_answer_file(answer_file):
         args.temperature,
         args.posterior_threshold,
         args.posterior_alpha,
+        args.top_p,
+        args.sampling,
+        args.fast,
         args.medusa_choices,
     )
 
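
For context, a minimal sketch of the acceptance rule that the parameters threaded through evaluate_posterior are meant to control when sampling='typical'. This is not code from the commit: the helper name typical_acceptance_mask is hypothetical, and the exact formula is an assumption about how posterior_threshold and posterior_alpha are commonly combined into an entropy-dependent floor; with a nucleus-style rule, top_p would instead bound the candidate set.

import torch

def typical_acceptance_mask(logits, candidates, temperature,
                            posterior_threshold, posterior_alpha):
    # logits: [n_candidates, seq_len, vocab]; candidates: [n_candidates, seq_len]
    # Probability of each speculated token under the base model's distribution.
    probs = torch.softmax(logits[:, :-1] / temperature, dim=-1)
    cand_prob = torch.gather(
        probs, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
    ).squeeze(-1)
    # Entropy-dependent floor: a token is kept if it is not too unlikely
    # relative to how uncertain the base model is at that position.
    entropy = -torch.sum(probs * torch.log(probs + 1e-5), dim=-1)
    floor = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy),
    )
    return cand_prob > floor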