66from .utils import *
77from .kv_cache import initialize_past_key_values
88from .medusa_choices import mc_sim_7b_63
9- from transformers import AutoTokenizer
9+ from transformers import AutoTokenizer , AutoConfig
1010import os
1111from huggingface_hub import hf_hub_download
1212
1313
14- class MedusaConfig (PretrainedConfig ):
15- """
16- Configuration class for Medusa model.
17-
18- Args:
19- medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2.
20- medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
21- base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
22- **kwargs: Additional keyword arguments to be passed to the parent class constructor.
23- """
24-
25- def __init__ (
26- self ,
27- medusa_num_heads = 4 ,
28- medusa_num_layers = 1 ,
29- base_model_name_or_path = "lmsys/vicuna-7b-v1.3" ,
30- ** kwargs ,
31- ):
32- super ().__init__ (** kwargs )
33- self .medusa_num_heads = medusa_num_heads
34- self .medusa_num_layers = medusa_num_layers
35- self .base_model_name_or_path = base_model_name_or_path
36-
37-
3814class ResBlock (nn .Module ):
3915 """
4016 A Residual Block module.
@@ -66,8 +42,7 @@ def forward(self, x):
6642 """
6743 return x + self .act (self .linear (x ))
6844
69-
70- class MedusaModel (nn .Module ):
45+ class MedusaLlamaModel (KVLlamaForCausalLM ):
7146 """The Medusa Language Model Head.
7247
7348 This module creates a series of prediction heads (based on the 'medusa' parameter)
@@ -77,22 +52,21 @@ class MedusaModel(nn.Module):
7752
7853 def __init__ (
7954 self ,
80- base_model ,
81- medusa_num_heads = 4 ,
82- medusa_num_layers = 1 ,
83- base_model_name_or_path = "lmsys/vicuna-7b-v1.3" ,
55+ config ,
8456 ):
8557 """
8658 Args:
87- base_model (nn.Module): The base language model to be used.
88- medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 3.
89- medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 0.
90- """
91- super ().__init__ ()
92- self .base_model = base_model
93- self .config = base_model .config
94- self .hidden_size = base_model .lm_head .weight .shape [- 1 ]
95- self .vocab_size = base_model .lm_head .weight .shape [0 ]
59+ config (PretrainedConfig): The configuration of the MedusaModel.
60+ """
61+ # Load the base model
62+ super ().__init__ (config )
63+ # For compatibility with the old APIs
64+
65+ medusa_num_heads = config .medusa_num_heads
66+ medusa_num_layers = config .medusa_num_layers
67+ base_model_name_or_path = config ._name_or_path
68+ self .hidden_size = config .hidden_size
69+ self .vocab_size = config .vocab_size
9670 self .medusa = medusa_num_heads
9771 self .medusa_num_layers = medusa_num_layers
9872 self .base_model_name_or_path = base_model_name_or_path
@@ -108,79 +82,44 @@ def __init__(
10882 ]
10983 )
11084
111- # Ensure medusa_head's dtype and device align with the base_model
112- self .medusa_head .to (self .base_model .dtype ).to (self .base_model .device )
113-
114- for i in range (medusa_num_heads ):
115- # Initialize the weights of each medusa_head using the base model's weights
116- self .medusa_head [i ][- 1 ].weight .data [:] = base_model .lm_head .weight .data [:]
117-
85+ # Add a link named base_model to self
86+ @property
87+ def base_model (self ):
88+ return self
89+
11890 def get_tokenizer (self ):
11991 """Get the tokenizer of the base model.
12092
12193 Returns:
12294 Tokenizer: The tokenizer of the base model.
12395 """
12496 return self .tokenizer
125-
97+
12698 @classmethod
12799 def from_pretrained (
128100 cls ,
129- medusa_head_name_or_path ,
130- base_model = None ,
131- medusa_num_heads = None ,
101+ pretrained_model_name_or_path ,
102+ * args ,
132103 ** kwargs ,
133104 ):
134- """
135- Args:
136- medusa_head_name_or_path (str): Name or path of the Medusa head to load.
137- **kwargs: Additional keyword arguments for loading the base model.
138-
139- Returns:
140- MedusaModel: A MedusaModel instance loaded from the given path.
141- """
142- medusa_config = MedusaConfig .from_pretrained (medusa_head_name_or_path )
143- if medusa_num_heads is not None :
144- print ("Overriding medusa_num_heads as:" , medusa_num_heads )
145- medusa_config .medusa_num_heads = medusa_num_heads
146- if base_model is not None :
147- print ("Overriding base_model as:" , base_model )
148- medusa_config .base_model_name_or_path = base_model
149-
150- # hardcode
151- if 'mistral' or 'zephyr' in medusa_config .base_model_name_or_path :
152- base_model = KVMistralForCausalLM .from_pretrained (
153- medusa_config .base_model_name_or_path , ** kwargs
154- )
155- else :
156- base_model = KVLlamaForCausalLM .from_pretrained (
157- medusa_config .base_model_name_or_path , ** kwargs
158- )
159-
160- model = cls (
161- base_model ,
162- medusa_config .medusa_num_heads ,
163- medusa_config .medusa_num_layers ,
164- medusa_config .base_model_name_or_path ,
105+ # Manually load config to ensure that the medusa_num_heads parameter is loaded
106+ config = AutoConfig .from_pretrained (pretrained_model_name_or_path )
107+ return super ().from_pretrained (
108+ pretrained_model_name_or_path ,
109+ * args ,
110+ ** kwargs ,
111+ config = config ,
165112 )
166- medusa_head_path = os .path .join (medusa_head_name_or_path , "medusa_lm_head.pt" )
167- if os .path .exists (medusa_head_path ):
168- filename = medusa_head_path
169- else :
170- filename = hf_hub_download (medusa_head_name_or_path , "medusa_lm_head.pt" )
171- medusa_head_state_dict = torch .load (filename , map_location = base_model .device )
172- model .medusa_head .load_state_dict (medusa_head_state_dict , strict = False )
173-
174- return model
175113
176114 def forward (
177115 self ,
178116 input_ids = None ,
179117 attention_mask = None ,
180- labels = None ,
181118 past_key_values = None ,
182119 output_orig = False ,
183120 position_ids = None ,
121+ medusa_forward = False ,
122+ ** kwargs ,
184123 ):
185124 """Forward pass of the MedusaModel.
186125
@@ -196,13 +135,22 @@ def forward(
196135 torch.Tensor: A tensor containing predictions from all Medusa heads.
197136 (Optional) Original predictions from the base model's LM head.
198137 """
138+ if not medusa_forward :
139+ return super ().forward (
140+ input_ids = input_ids ,
141+ attention_mask = attention_mask ,
142+ past_key_values = past_key_values ,
143+ position_ids = position_ids ,
144+ ** kwargs ,
145+ )
199146 with torch .inference_mode ():
200147 # Pass input through the base model
201148 outputs = self .base_model .model (
202149 input_ids = input_ids ,
203150 attention_mask = attention_mask ,
204151 past_key_values = past_key_values ,
205152 position_ids = position_ids ,
153+ ** kwargs ,
206154 )
207155 if output_orig :
208156 orig = self .base_model .lm_head (outputs [0 ])
@@ -337,3 +285,6 @@ def medusa_generate(
337285
338286 if self .tokenizer .eos_token_id in input_ids [0 , input_len :]:
339287 break
288+
289+ # Currently only supports LlamaModel
290+ MedusaModel = MedusaLlamaModel
0 commit comments