2222from medusa .model .utils import *
2323from medusa .model .medusa_model import MedusaModel
2424from medusa .model .kv_cache import initialize_past_key_values
25- from medusa .model .medusa_choices import medusa_choices
25+ from medusa .model .medusa_choices import *
2626
2727def medusa_forward (input_ids , model , tokenizer , medusa_choices , temperature , posterior_threshold , posterior_alpha , max_steps = 512 ):
2828 assert input_ids .shape [0 ] == 1 , "Only support batch size 1 for now!!"
@@ -191,7 +191,7 @@ def get_model_answers(
191191 tokenizer = model .get_tokenizer ()
192192
193193 model .eval ()
194- print ('Check model state:' ,model .training )
194+ print ('Check model training state:' ,model .training )
195195
196196 cuda_visible_devices = os .environ .get ('CUDA_VISIBLE_DEVICES' )
197197 print ('CUDA VISIBLE DEVICES:' , cuda_visible_devices )
@@ -456,14 +456,20 @@ def reorg_answer_file(answer_file):
456456 help = "The posterior alpha for medusa sampling." ,
457457 )
458458
459+ parser .add_argument (
460+ "--medusa-choices" ,
461+ type = str ,
462+ default = "mc_sim_7b_63" ,
463+ help = "The medusa choices for medusa sampling." ,
464+ )
459465
460466
461467
462468
463469 args = parser .parse_args ()
464470
465471 args .model_id = args .model_id + "-temperature-" + str (args .temperature )+ "-posterior_threshold-" + str (args .posterior_threshold )+ "-posterior_alpha-" + str (args .posterior_alpha )
466-
472+ args . medusa_choices = eval ( args . medusa_choices )
467473 if args .num_gpus_total // args .num_gpus_per_model > 1 :
468474 import ray
469475
@@ -493,7 +499,7 @@ def reorg_answer_file(answer_file):
493499 args .temperature ,
494500 args .posterior_threshold ,
495501 args .posterior_alpha ,
496- medusa_choices ,
502+ args . medusa_choices ,
497503 )
498504
499505 reorg_answer_file (answer_file )