
Commit f22d72f

Solve the compatibility with v0.1

This commit lets MedusaModel.from_pretrained fall back to the legacy Medusa v0.1 checkpoint layout (a MedusaConfig in config.json plus a separate medusa_lm_head.pt weight file) whenever AutoConfig cannot parse the checkpoint's config.

1 parent e1154c0

File tree

1 file changed (+58, -8 lines)


medusa/model/medusa_model.py

Lines changed: 58 additions & 8 deletions
@@ -16,6 +16,28 @@
 import os
 from huggingface_hub import hf_hub_download
 
+class MedusaConfig(PretrainedConfig):
+    """
+    Configuration class for Medusa model.
+
+    Args:
+        medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 5.
+        medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
+        base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
+        **kwargs: Additional keyword arguments to be passed to the parent class constructor.
+    """
+
+    def __init__(
+        self,
+        medusa_num_heads=5,
+        medusa_num_layers=1,
+        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.medusa_num_heads = medusa_num_heads
+        self.medusa_num_layers = medusa_num_layers
+        self.base_model_name_or_path = base_model_name_or_path
 
 class ResBlock(nn.Module):
     """
@@ -106,13 +128,34 @@ def from_pretrained(
         **kwargs,
     ):
         # Manually load config to ensure that the medusa_num_heads parameter is loaded
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-        return super().from_pretrained(
-            pretrained_model_name_or_path,
-            *args,
-            **kwargs,
-            config=config,
-        )
+        try:
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            return super().from_pretrained(
+                pretrained_model_name_or_path,
+                *args,
+                **kwargs,
+                config=config,
+            )
+        except:
+            config = MedusaConfig.from_pretrained(pretrained_model_name_or_path)
+            base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path)
+            base_model_config.medusa_num_heads = 5  # TODO: fix the uploaded config (only include 2 heads)
+            base_model_config.medusa_num_layers = config.medusa_num_layers
+            model = super().from_pretrained(
+                config.base_model_name_or_path,
+                *args,
+                **kwargs,
+                config=base_model_config,
+            )
+            medusa_head_path = os.path.join(pretrained_model_name_or_path, "medusa_lm_head.pt")
+            if os.path.exists(medusa_head_path):
+                filename = medusa_head_path
+            else:
+                filename = hf_hub_download(pretrained_model_name_or_path, "medusa_lm_head.pt")
+            medusa_head_state_dict = torch.load(filename, map_location=model.device)
+            model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)
+            return model
+
 
     def get_tokenizer(self):
         """Get the tokenizer of the base model.
@@ -326,7 +369,14 @@ def from_pretrained(
         **kwargs,
     ):
         # Manually load config to ensure that the medusa_num_heads parameter is loaded
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        try:
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        except:
+            # MEDUSA-v0.1 load
+            config = MedusaConfig.from_pretrained(pretrained_model_name_or_path)
+            base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path)
+            config.model_type = base_model_config.model_type
+
         if config.model_type == "llama":
             return MedusaModelLlama.from_pretrained(
                 pretrained_model_name_or_path,
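
With both patches applied, old and new checkpoints share one entry point: AutoConfig-readable checkpoints take the fast path, and anything else is retried as a v0.1 layout whose model_type is recovered from the base model's config. A hedged usage sketch (the repo id is a placeholder and the class name and import path are assumed from this file):

    from medusa.model.medusa_model import MedusaModel  # assumed import path

    # Dispatches to MedusaModelLlama for llama-type configs; a v0.1
    # checkpoint reaches the same branch via the except path above.
    model = MedusaModel.from_pretrained("org/medusa-checkpoint")  # placeholder id
    tokenizer = model.get_tokenizer()

One caveat worth noting: the bare except: routes every config-loading failure, network errors included, into the v0.1 path; except Exception: with a narrower check would be the more conservative spelling.
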
