55from .utils import *
66from .kv_cache import initialize_past_key_values
77from .medusa_choices import mc_sim_7b_63
8- from transformers import AutoTokenizer
8+ from transformers import AutoTokenizer , AutoConfig
99import os
1010from huggingface_hub import hf_hub_download
1111
1212
13- class MedusaConfig (PretrainedConfig ):
14- """
15- Configuration class for Medusa model.
16-
17- Args:
18- medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 4.
19- medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
20- base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
21- **kwargs: Additional keyword arguments to be passed to the parent class constructor.
22- """
23-
24- def __init__ (
25- self ,
26- medusa_num_heads = 4 ,
27- medusa_num_layers = 1 ,
28- base_model_name_or_path = "lmsys/vicuna-7b-v1.3" ,
29- ** kwargs ,
30- ):
31- super ().__init__ (** kwargs )
32- self .medusa_num_heads = medusa_num_heads
33- self .medusa_num_layers = medusa_num_layers
34- self .base_model_name_or_path = base_model_name_or_path
35-
36-
3713class ResBlock (nn .Module ):
3814 """
3915 A Residual Block module.
@@ -65,8 +41,7 @@ def forward(self, x):
6541 """
6642 return x + self .act (self .linear (x ))
6743
68-
69- class MedusaModel (nn .Module ):
44+ class MedusaLlamaModel (KVLlamaForCausalLM ):
7045 """The Medusa Language Model Head.
7146
7247 This module creates a series of prediction heads (based on the 'medusa' parameter)
@@ -76,22 +51,21 @@ class MedusaModel(nn.Module):
7651
7752 def __init__ (
7853 self ,
79- base_model ,
80- medusa_num_heads = 4 ,
81- medusa_num_layers = 1 ,
82- base_model_name_or_path = "lmsys/vicuna-7b-v1.3" ,
54+ config ,
8355 ):
8456 """
8557 Args:
86- base_model (nn.Module): The base language model to be used.
87- medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 4.
88- medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 1.
89- """
90- super ().__init__ ()
91- self .base_model = base_model
92- self .config = base_model .config
93- self .hidden_size = base_model .lm_head .weight .shape [- 1 ]
94- self .vocab_size = base_model .lm_head .weight .shape [0 ]
58+ config (PretrainedConfig): The configuration of the MedusaModel.
59+ """
60+ # Initialize the base model (KVLlamaForCausalLM) from the config
61+ super ().__init__ (config )
62+ # For compatibility with the old APIs
63+
64+ medusa_num_heads = config .medusa_num_heads
65+ medusa_num_layers = config .medusa_num_layers
66+ base_model_name_or_path = config ._name_or_path
67+ self .hidden_size = config .hidden_size
68+ self .vocab_size = config .vocab_size
9569 self .medusa = medusa_num_heads
9670 self .medusa_num_layers = medusa_num_layers
9771 self .base_model_name_or_path = base_model_name_or_path
@@ -107,73 +81,44 @@ def __init__(
10781 ]
10882 )
10983
110- # Ensure medusa_head's dtype and device align with the base_model
111- self .medusa_head .to (self .base_model .dtype ).to (self .base_model .device )
112-
113- for i in range (medusa_num_heads ):
114- # Initialize the weights of each medusa_head using the base model's weights
115- self .medusa_head [i ][- 1 ].weight .data [:] = base_model .lm_head .weight .data [:]
116-
84+ # Add a link named base_model to self
85+ @property
86+ def base_model (self ):
87+ return self
88+
11789 def get_tokenizer (self ):
11890 """Get the tokenizer of the base model.
11991
12092 Returns:
12193 Tokenizer: The tokenizer of the base model.
12294 """
12395 return self .tokenizer
124-
96+
12597 @classmethod
12698 def from_pretrained (
12799 cls ,
128- medusa_head_name_or_path ,
129- base_model = None ,
130- medusa_num_heads = None ,
100+ pretrained_model_name_or_path ,
101+ * args ,
131102 ** kwargs ,
132103 ):
133- """
134- Args:
135- medusa_head_name_or_path (str): Name or path of the Medusa head to load.
136- **kwargs: Additional keyword arguments for loading the base model.
137-
138- Returns:
139- MedusaModel: A MedusaModel instance loaded from the given path.
140- """
141- medusa_config = MedusaConfig .from_pretrained (medusa_head_name_or_path )
142- if medusa_num_heads is not None :
143- print ("Overriding medusa_num_heads as:" , medusa_num_heads )
144- medusa_config .medusa_num_heads = medusa_num_heads
145- if base_model is not None :
146- print ("Overriding base_model as:" , base_model )
147- medusa_config .base_model_name_or_path = base_model
148-
149- base_model = KVLlamaForCausalLM .from_pretrained (
150- medusa_config .base_model_name_or_path , ** kwargs
151- )
152-
153- model = cls (
154- base_model ,
155- medusa_config .medusa_num_heads ,
156- medusa_config .medusa_num_layers ,
157- medusa_config .base_model_name_or_path ,
104+ # Manually load config to ensure that the medusa_num_heads parameter is loaded
105+ config = AutoConfig .from_pretrained (pretrained_model_name_or_path )
106+ return super ().from_pretrained (
107+ pretrained_model_name_or_path ,
108+ * args ,
109+ ** kwargs ,
110+ config = config ,
158111 )
159- medusa_head_path = os .path .join (medusa_head_name_or_path , "medusa_lm_head.pt" )
160- if os .path .exists (medusa_head_path ):
161- filename = medusa_head_path
162- else :
163- filename = hf_hub_download (medusa_head_name_or_path , "medusa_lm_head.pt" )
164- medusa_head_state_dict = torch .load (filename , map_location = base_model .device )
165- model .medusa_head .load_state_dict (medusa_head_state_dict , strict = False )
166-
167- return model
168112
169113 def forward (
170114 self ,
171115 input_ids = None ,
172116 attention_mask = None ,
173- labels = None ,
174117 past_key_values = None ,
175118 output_orig = False ,
176119 position_ids = None ,
120+ medusa_forward = False ,
121+ ** kwargs ,
177122 ):
178123 """Forward pass of the MedusaModel.
179124
@@ -189,13 +134,22 @@ def forward(
189134 torch.Tensor: A tensor containing predictions from all Medusa heads.
190135 (Optional) Original predictions from the base model's LM head.
191136 """
137+ if not medusa_forward :
138+ return super ().forward (
139+ input_ids = input_ids ,
140+ attention_mask = attention_mask ,
141+ past_key_values = past_key_values ,
142+ position_ids = position_ids ,
143+ ** kwargs ,
144+ )
192145 with torch .inference_mode ():
193146 # Pass input through the base model
194147 outputs = self .base_model .model (
195148 input_ids = input_ids ,
196149 attention_mask = attention_mask ,
197150 past_key_values = past_key_values ,
198151 position_ids = position_ids ,
152+ ** kwargs ,
199153 )
200154 if output_orig :
201155 orig = self .base_model .lm_head (outputs [0 ])
@@ -330,3 +284,6 @@ def medusa_generate(
330284
331285 if self .tokenizer .eos_token_id in input_ids [0 , input_len :]:
332286 break
287+
288+ # Currently only LlamaModel is supported
289+ MedusaModel = MedusaLlamaModel
0 commit comments