Skip to content

Commit 95f1271

Browse files
committed
medusa choice auto dispatch
1 parent a4ec58e commit 95f1271

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

medusa/model/medusa_model.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from transformers import PreTrainedModel, PretrainedConfig
1212
from .utils import *
1313
from .kv_cache import initialize_past_key_values
14-
from .medusa_choices import mc_sim_7b_63
14+
from .medusa_choices import *
1515
from transformers import AutoTokenizer, AutoConfig
1616
import os
1717
from huggingface_hub import hf_hub_download
18+
import warnings
1819

1920
class MedusaConfig(PretrainedConfig):
2021
"""
@@ -218,6 +219,18 @@ def forward(
218219
if output_orig:
219220
return torch.stack(medusa_logits, dim=0), outputs, orig
220221
return torch.stack(medusa_logits, dim=0)
222+
def get_medusa_choice(self, model_name):
    """Select a default Medusa choice configuration from a base-model name.

    The name is matched by substring, ignoring case: a name containing
    ``vicuna`` together with ``7b``, ``13b`` or ``33b`` selects the
    corresponding ``vicuna_*_stage2`` configuration, and a name containing
    ``zephyr`` selects ``zephyr_stage2``. Anything else (including a
    Vicuna name with no recognised size tag, or ``None``) emits a warning
    and falls back to ``mc_sim_7b_63``.

    Args:
        model_name: Base model name or path. ``None`` is tolerated and is
            treated as "unknown model".

    Returns:
        One of the module-level Medusa choice configurations.
        NOTE(review): these names are presumably supplied by the star
        import from ``.medusa_choices`` -- confirm they all exist there.
    """
    # Normalise once: checkpoints are often named like "Vicuna-7B", and a
    # case-sensitive test would silently route them to the generic fallback.
    # Going through str() also accepts path-like objects.
    name = '' if model_name is None else str(model_name).lower()

    if 'vicuna' in name:
        if '7b' in name:
            return vicuna_7b_stage2
        if '13b' in name:
            return vicuna_13b_stage2
        if '33b' in name:
            return vicuna_33b_stage2
        # Unrecognised Vicuna size: fall through to the generic default.
    elif 'zephyr' in name:
        return zephyr_stage2

    warnings.warn('Please specify medusa choice configuration!')
    return mc_sim_7b_63
221234

222235
def medusa_generate(
223236
self,
@@ -227,7 +240,7 @@ def medusa_generate(
227240
max_steps=512,
228241
# The hyperparameters below are for the Medusa
229242
# top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
230-
medusa_choices=mc_sim_7b_63,
243+
medusa_choices=None,
231244
posterior_threshold=0.09, # threshold validation of Medusa output
232245
# another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
233246
posterior_alpha=0.3,
@@ -256,6 +269,9 @@ def medusa_generate(
256269
input_ids = input_ids.clone()
257270

258271
# Cache medusa buffers (the fixed patterns for tree attention)
272+
if medusa_choices is None:
273+
medusa_choices = self.get_medusa_choice(self.base_model_name_or_path)
274+
259275
if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
260276
# Load the cached medusa buffer
261277
medusa_buffers = self.medusa_buffers

0 commit comments

Comments
 (0)