import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
+from transformers import AutoTokenizer
from .utils import *
from .kv_cache import initialize_past_key_values
-from .medusa_choices import mc_sim_7b_63
-from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download


class MedusaConfig(PretrainedConfig):
    def __init__(
        self,
-        medusa_num_heads=4,
+        medusa_num_heads=2,
        medusa_num_layers=1,
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
        **kwargs,
@@ -111,7 +110,6 @@ def get_tokenizer(self):
    def from_pretrained(
        cls,
        medusa_head_name_or_path,
-        medusa_num_heads=None,
        **kwargs,
    ):
        """
@@ -123,12 +121,9 @@ def from_pretrained(
            MedusaModel: A MedusaModel instance loaded from the given path.
        """
        medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
-        if medusa_num_heads is not None:
-            medusa_config.medusa_num_heads = medusa_num_heads
        base_model = KVLlamaForCausalLM.from_pretrained(
            medusa_config.base_model_name_or_path, **kwargs
        )
-
        model = cls(
            base_model,
            medusa_config.medusa_num_heads,
@@ -196,7 +191,7 @@ def medusa_generate(
        max_steps=512,
        # The hyperparameters below are for the Medusa
        # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
-        medusa_choices=mc_sim_7b_63,
+        medusa_choices=[1, 7, 6],
        posterior_threshold=0.09,  # threshold validation of Medusa output
        # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
        posterior_alpha=0.3,
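For context, posterior_threshold and posterior_alpha parameterize the typical-acceptance check that validates drafted tokens against the base model's distribution; as the comment notes, posterior_alpha is recommended to be sqrt(posterior_threshold), and sqrt(0.09) = 0.3. Below is a minimal sketch of one plausible form of that check; the function name, shapes, and exact formula are assumptions for illustration, not the code this commit calls.

import torch

def typical_accept(probs, token_id, posterior_threshold=0.09, posterior_alpha=0.3):
    # probs: the base model's probability distribution at one position (1-D, sums to 1).
    # Accept the drafted token if its probability clears a bar of
    # min(posterior_threshold, posterior_alpha * exp(-entropy)): flat (high-entropy)
    # distributions get a looser bar, peaked ones a stricter one.
    entropy = -(probs * torch.log(probs + 1e-5)).sum()
    bar = torch.minimum(
        torch.tensor(posterior_threshold),
        posterior_alpha * torch.exp(-entropy),
    )
    return bool(probs[token_id] > bar)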
@@ -230,6 +225,7 @@ def medusa_generate(
        self.medusa_buffers = medusa_buffers
        self.medusa_choices = medusa_choices

+        medusa_topk = medusa_choices[1:]

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
@@ -264,8 +260,9 @@ def medusa_generate(
            candidates, tree_candidates = generate_candidates(
                medusa_logits,
                logits,
+                medusa_topk,
                medusa_buffers["tree_indices"],
-                medusa_buffers["retrieve_indices"],
+                temperature,
            )

            # Use tree attention to verify the candidates and get predictions
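With medusa_choices given as a flat list, the first entry is the base model's top-1 prediction for the next token and the remaining entries (medusa_topk = medusa_choices[1:]) are the top-k widths of the Medusa heads. The sketch below only illustrates how [1, 7, 6] expands into candidate continuations, using dummy logits and hypothetical names; it is not the generate_candidates implementation called above, which also consumes tree_indices and temperature.

import itertools
import torch

vocab_size = 32000
logits = torch.randn(vocab_size)            # base-model logits for the next token
medusa_logits = torch.randn(2, vocab_size)  # one row of logits per Medusa head
medusa_choices = [1, 7, 6]
medusa_topk = medusa_choices[1:]            # [7, 6], as in the hunk above

base_tokens = logits.topk(medusa_choices[0]).indices.tolist()  # top-1 next token
head_tokens = [
    medusa_logits[i].topk(k).indices.tolist()                  # top-k per Medusa head
    for i, k in enumerate(medusa_topk)
]

# Cartesian product: 1 * 7 * 6 = 42 candidate continuations, each 3 tokens long.
candidates = [list(c) for c in itertools.product(base_tokens, *head_tokens)]
print(len(candidates), len(candidates[0]))  # -> 42 3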