Skip to content

Commit 46f911e

Browse files
author
Tianle Cai
committed
update model loading APIs
1 parent 6294228 commit 46f911e

File tree

5 files changed

+379
-89
lines changed

5 files changed

+379
-89
lines changed

medusa/model/medusa_model.py

Lines changed: 43 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,11 @@
55
from .utils import *
66
from .kv_cache import initialize_past_key_values
77
from .medusa_choices import mc_sim_7b_63
8-
from transformers import AutoTokenizer
8+
from transformers import AutoTokenizer, AutoConfig
99
import os
1010
from huggingface_hub import hf_hub_download
1111

1212

13-
class MedusaConfig(PretrainedConfig):
14-
"""
15-
Configuration class for Medusa model.
16-
17-
Args:
18-
medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2.
19-
medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
20-
base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
21-
**kwargs: Additional keyword arguments to be passed to the parent class constructor.
22-
"""
23-
24-
def __init__(
25-
self,
26-
medusa_num_heads=4,
27-
medusa_num_layers=1,
28-
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
29-
**kwargs,
30-
):
31-
super().__init__(**kwargs)
32-
self.medusa_num_heads = medusa_num_heads
33-
self.medusa_num_layers = medusa_num_layers
34-
self.base_model_name_or_path = base_model_name_or_path
35-
36-
3713
class ResBlock(nn.Module):
3814
"""
3915
A Residual Block module.
@@ -65,8 +41,7 @@ def forward(self, x):
6541
"""
6642
return x + self.act(self.linear(x))
6743

68-
69-
class MedusaModel(nn.Module):
44+
class MedusaLlamaModel(KVLlamaForCausalLM):
7045
"""The Medusa Language Model Head.
7146
7247
This module creates a series of prediction heads (based on the 'medusa' parameter)
@@ -76,22 +51,21 @@ class MedusaModel(nn.Module):
7651

7752
def __init__(
7853
self,
79-
base_model,
80-
medusa_num_heads=4,
81-
medusa_num_layers=1,
82-
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
54+
config,
8355
):
8456
"""
8557
Args:
86-
base_model (nn.Module): The base language model to be used.
87-
medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 3.
88-
medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 0.
89-
"""
90-
super().__init__()
91-
self.base_model = base_model
92-
self.config = base_model.config
93-
self.hidden_size = base_model.lm_head.weight.shape[-1]
94-
self.vocab_size = base_model.lm_head.weight.shape[0]
58+
config (PretrainedConfig): The configuration of the MedusaModel.
59+
"""
60+
# Load the base model
61+
super().__init__(config)
62+
# For compatibility with the old APIs
63+
64+
medusa_num_heads = config.medusa_num_heads
65+
medusa_num_layers = config.medusa_num_layers
66+
base_model_name_or_path = config._name_or_path
67+
self.hidden_size = config.hidden_size
68+
self.vocab_size = config.vocab_size
9569
self.medusa = medusa_num_heads
9670
self.medusa_num_layers = medusa_num_layers
9771
self.base_model_name_or_path = base_model_name_or_path
@@ -107,73 +81,44 @@ def __init__(
10781
]
10882
)
10983

110-
# Ensure medusa_head's dtype and device align with the base_model
111-
self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)
112-
113-
for i in range(medusa_num_heads):
114-
# Initialize the weights of each medusa_head using the base model's weights
115-
self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
116-
84+
# Add a link named base_model to self
85+
@property
86+
def base_model(self):
87+
return self
88+
11789
def get_tokenizer(self):
11890
"""Get the tokenizer of the base model.
11991
12092
Returns:
12193
Tokenizer: The tokenizer of the base model.
12294
"""
12395
return self.tokenizer
124-
96+
12597
@classmethod
12698
def from_pretrained(
12799
cls,
128-
medusa_head_name_or_path,
129-
base_model=None,
130-
medusa_num_heads=None,
100+
pretrained_model_name_or_path,
101+
*args,
131102
**kwargs,
132103
):
133-
"""
134-
Args:
135-
medusa_head_name_or_path (str): Name or path of the Medusa head to load.
136-
**kwargs: Additional keyword arguments for loading the base model.
137-
138-
Returns:
139-
MedusaModel: A MedusaModel instance loaded from the given path.
140-
"""
141-
medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
142-
if medusa_num_heads is not None:
143-
print("Overriding medusa_num_heads as:", medusa_num_heads)
144-
medusa_config.medusa_num_heads = medusa_num_heads
145-
if base_model is not None:
146-
print("Overriding base_model as:", base_model)
147-
medusa_config.base_model_name_or_path = base_model
148-
149-
base_model = KVLlamaForCausalLM.from_pretrained(
150-
medusa_config.base_model_name_or_path, **kwargs
151-
)
152-
153-
model = cls(
154-
base_model,
155-
medusa_config.medusa_num_heads,
156-
medusa_config.medusa_num_layers,
157-
medusa_config.base_model_name_or_path,
104+
# Manually load config to ensure that the medusa_num_heads parameter is loaded
105+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
106+
return super().from_pretrained(
107+
pretrained_model_name_or_path,
108+
*args,
109+
**kwargs,
110+
config=config,
158111
)
159-
medusa_head_path = os.path.join(medusa_head_name_or_path, "medusa_lm_head.pt")
160-
if os.path.exists(medusa_head_path):
161-
filename = medusa_head_path
162-
else:
163-
filename = hf_hub_download(medusa_head_name_or_path, "medusa_lm_head.pt")
164-
medusa_head_state_dict = torch.load(filename, map_location=base_model.device)
165-
model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)
166-
167-
return model
168112

169113
def forward(
170114
self,
171115
input_ids=None,
172116
attention_mask=None,
173-
labels=None,
174117
past_key_values=None,
175118
output_orig=False,
176119
position_ids=None,
120+
medusa_forward=False,
121+
**kwargs,
177122
):
178123
"""Forward pass of the MedusaModel.
179124
@@ -189,13 +134,22 @@ def forward(
189134
torch.Tensor: A tensor containing predictions from all Medusa heads.
190135
(Optional) Original predictions from the base model's LM head.
191136
"""
137+
if not medusa_forward:
138+
return super().forward(
139+
input_ids=input_ids,
140+
attention_mask=attention_mask,
141+
past_key_values=past_key_values,
142+
position_ids=position_ids,
143+
**kwargs,
144+
)
192145
with torch.inference_mode():
193146
# Pass input through the base model
194147
outputs = self.base_model.model(
195148
input_ids=input_ids,
196149
attention_mask=attention_mask,
197150
past_key_values=past_key_values,
198151
position_ids=position_ids,
152+
**kwargs,
199153
)
200154
if output_orig:
201155
orig = self.base_model.lm_head(outputs[0])
@@ -330,3 +284,6 @@ def medusa_generate(
330284

331285
if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
332286
break
287+
288+
# Currently only support LlamaModel
289+
MedusaModel = MedusaLlamaModel

0 commit comments

Comments
 (0)