Skip to content

Commit a357895

Browse files
committed
Merge branch 'v1.0-prerelease' of github.com:FasterDecoding/Medusa into v1.0-prerelease
2 parents af8c7d9 + 46f911e commit a357895

File tree

5 files changed: +379 additions, -95 deletions

medusa/model/medusa_model.py

Lines changed: 43 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -6,35 +6,11 @@
66
from .utils import *
77
from .kv_cache import initialize_past_key_values
88
from .medusa_choices import mc_sim_7b_63
9-
from transformers import AutoTokenizer
9+
from transformers import AutoTokenizer, AutoConfig
1010
import os
1111
from huggingface_hub import hf_hub_download
1212

1313

14-
class MedusaConfig(PretrainedConfig):
15-
"""
16-
Configuration class for Medusa model.
17-
18-
Args:
19-
medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2.
20-
medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
21-
base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
22-
**kwargs: Additional keyword arguments to be passed to the parent class constructor.
23-
"""
24-
25-
def __init__(
26-
self,
27-
medusa_num_heads=4,
28-
medusa_num_layers=1,
29-
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
30-
**kwargs,
31-
):
32-
super().__init__(**kwargs)
33-
self.medusa_num_heads = medusa_num_heads
34-
self.medusa_num_layers = medusa_num_layers
35-
self.base_model_name_or_path = base_model_name_or_path
36-
37-
3814
class ResBlock(nn.Module):
3915
"""
4016
A Residual Block module.
@@ -66,8 +42,7 @@ def forward(self, x):
6642
"""
6743
return x + self.act(self.linear(x))
6844

69-
70-
class MedusaModel(nn.Module):
45+
class MedusaLlamaModel(KVLlamaForCausalLM):
7146
"""The Medusa Language Model Head.
7247
7348
This module creates a series of prediction heads (based on the 'medusa' parameter)
@@ -77,22 +52,21 @@ class MedusaModel(nn.Module):
7752

7853
def __init__(
7954
self,
80-
base_model,
81-
medusa_num_heads=4,
82-
medusa_num_layers=1,
83-
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
55+
config,
8456
):
8557
"""
8658
Args:
87-
base_model (nn.Module): The base language model to be used.
88-
medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 3.
89-
medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 0.
90-
"""
91-
super().__init__()
92-
self.base_model = base_model
93-
self.config = base_model.config
94-
self.hidden_size = base_model.lm_head.weight.shape[-1]
95-
self.vocab_size = base_model.lm_head.weight.shape[0]
59+
config (PretrainedConfig): The configuration of the MedusaModel.
60+
"""
61+
# Load the base model
62+
super().__init__(config)
63+
# For compatibility with the old APIs
64+
65+
medusa_num_heads = config.medusa_num_heads
66+
medusa_num_layers = config.medusa_num_layers
67+
base_model_name_or_path = config._name_or_path
68+
self.hidden_size = config.hidden_size
69+
self.vocab_size = config.vocab_size
9670
self.medusa = medusa_num_heads
9771
self.medusa_num_layers = medusa_num_layers
9872
self.base_model_name_or_path = base_model_name_or_path
@@ -108,79 +82,44 @@ def __init__(
10882
]
10983
)
11084

111-
# Ensure medusa_head's dtype and device align with the base_model
112-
self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)
113-
114-
for i in range(medusa_num_heads):
115-
# Initialize the weights of each medusa_head using the base model's weights
116-
self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
117-
85+
# Add a link named base_model to self
86+
@property
87+
def base_model(self):
88+
return self
89+
11890
def get_tokenizer(self):
11991
"""Get the tokenizer of the base model.
12092
12193
Returns:
12294
Tokenizer: The tokenizer of the base model.
12395
"""
12496
return self.tokenizer
125-
97+
12698
@classmethod
12799
def from_pretrained(
128100
cls,
129-
medusa_head_name_or_path,
130-
base_model=None,
131-
medusa_num_heads=None,
101+
pretrained_model_name_or_path,
102+
*args,
132103
**kwargs,
133104
):
134-
"""
135-
Args:
136-
medusa_head_name_or_path (str): Name or path of the Medusa head to load.
137-
**kwargs: Additional keyword arguments for loading the base model.
138-
139-
Returns:
140-
MedusaModel: A MedusaModel instance loaded from the given path.
141-
"""
142-
medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
143-
if medusa_num_heads is not None:
144-
print("Overriding medusa_num_heads as:", medusa_num_heads)
145-
medusa_config.medusa_num_heads = medusa_num_heads
146-
if base_model is not None:
147-
print("Overriding base_model as:", base_model)
148-
medusa_config.base_model_name_or_path = base_model
149-
150-
# hardcode
151-
if 'mistral' or 'zephyr' in medusa_config.base_model_name_or_path:
152-
base_model = KVMistralForCausalLM.from_pretrained(
153-
medusa_config.base_model_name_or_path, **kwargs
154-
)
155-
else:
156-
base_model = KVLlamaForCausalLM.from_pretrained(
157-
medusa_config.base_model_name_or_path, **kwargs
158-
)
159-
160-
model = cls(
161-
base_model,
162-
medusa_config.medusa_num_heads,
163-
medusa_config.medusa_num_layers,
164-
medusa_config.base_model_name_or_path,
105+
# Manually load config to ensure that the medusa_num_heads parameter is loaded
106+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
107+
return super().from_pretrained(
108+
pretrained_model_name_or_path,
109+
*args,
110+
**kwargs,
111+
config=config,
165112
)
166-
medusa_head_path = os.path.join(medusa_head_name_or_path, "medusa_lm_head.pt")
167-
if os.path.exists(medusa_head_path):
168-
filename = medusa_head_path
169-
else:
170-
filename = hf_hub_download(medusa_head_name_or_path, "medusa_lm_head.pt")
171-
medusa_head_state_dict = torch.load(filename, map_location=base_model.device)
172-
model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)
173-
174-
return model
175113

176114
def forward(
177115
self,
178116
input_ids=None,
179117
attention_mask=None,
180-
labels=None,
181118
past_key_values=None,
182119
output_orig=False,
183120
position_ids=None,
121+
medusa_forward=False,
122+
**kwargs,
184123
):
185124
"""Forward pass of the MedusaModel.
186125
@@ -196,13 +135,22 @@ def forward(
196135
torch.Tensor: A tensor containing predictions from all Medusa heads.
197136
(Optional) Original predictions from the base model's LM head.
198137
"""
138+
if not medusa_forward:
139+
return super().forward(
140+
input_ids=input_ids,
141+
attention_mask=attention_mask,
142+
past_key_values=past_key_values,
143+
position_ids=position_ids,
144+
**kwargs,
145+
)
199146
with torch.inference_mode():
200147
# Pass input through the base model
201148
outputs = self.base_model.model(
202149
input_ids=input_ids,
203150
attention_mask=attention_mask,
204151
past_key_values=past_key_values,
205152
position_ids=position_ids,
153+
**kwargs,
206154
)
207155
if output_orig:
208156
orig = self.base_model.lm_head(outputs[0])
@@ -337,3 +285,6 @@ def medusa_generate(
337285

338286
if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
339287
break
288+
289+
# Currently only support LlamaModel
290+
MedusaModel = MedusaLlamaModel

0 commit comments

Comments
 (0)