@@ -11,10 +11,11 @@
 from transformers import PreTrainedModel, PretrainedConfig
 from .utils import *
 from .kv_cache import initialize_past_key_values
-from .medusa_choices import mc_sim_7b_63
+from .medusa_choices import *
 from transformers import AutoTokenizer, AutoConfig
 import os
 from huggingface_hub import hf_hub_download
+import warnings
 
 class MedusaConfig(PretrainedConfig):
     """
@@ -218,6 +219,18 @@ def forward(
         if output_orig:
             return torch.stack(medusa_logits, dim=0), outputs, orig
         return torch.stack(medusa_logits, dim=0)
+    def get_medusa_choice(self, model_name):
+        if 'vicuna' in model_name:
+            if '7b' in model_name:
+                return vicuna_7b_stage2
+            elif '13b' in model_name:
+                return vicuna_13b_stage2
+            elif '33b' in model_name:
+                return vicuna_33b_stage2
+        elif 'zephyr' in model_name:
+            return zephyr_stage2
+        warnings.warn('Please specify medusa choice configuration!')
+        return mc_sim_7b_63
 
     def medusa_generate(
         self,
@@ -227,7 +240,7 @@ def medusa_generate(
         max_steps=512,
         # The hyperparameters below are for the Medusa
         # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
-        medusa_choices=mc_sim_7b_63,
+        medusa_choices=None,
         posterior_threshold=0.09,  # threshold validation of Medusa output
         # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
         posterior_alpha=0.3,
@@ -256,6 +269,9 @@ def medusa_generate(
         input_ids = input_ids.clone()
 
         # Cache medusa buffers (the fixed patterns for tree attention)
+        if medusa_choices is None:
+            medusa_choices = self.get_medusa_choice(self.base_model_name_or_path)
+
         if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
             # Load the cached medusa buffer
             medusa_buffers = self.medusa_buffers
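
With this change, medusa_choices can be omitted: the default of None makes medusa_generate call get_medusa_choice(self.base_model_name_or_path), which picks a tree configuration from the model name and falls back to mc_sim_7b_63 with a warning. A minimal usage sketch, assuming the repo's MedusaModel.from_pretrained and get_tokenizer helpers, the generator's {"text": ...} output format, and an illustrative vicuna-7b checkpoint name:

import torch
from medusa.model.medusa_model import MedusaModel

# Illustrative checkpoint; 'vicuna' + '7b' in the name makes
# get_medusa_choice() return the vicuna_7b_stage2 tree.
model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = model.get_tokenizer()
input_ids = tokenizer("Once upon a time,", return_tensors="pt").input_ids.to(model.device)

# medusa_choices is left unset, so the new None default triggers get_medusa_choice().
text = ""
for out in model.medusa_generate(input_ids, max_steps=64):
    text = out["text"]  # each step yields the decoded text so far (assumed output format)
print(text)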