Commit 2b8020f

Creating medusa2.
Turns out creating entire weight matrices for the lm_heads costs a huge amount of VRAM (especially for multilingual models like Gemma) and is not necessary at all to get good speculation. This PR modifies the legacy code to create new Medusa models without duplicating this lm_head, making them much more efficient to run. It also increments the version number of the config so users can tell how to actually run the model.
1 parent 5e98053 commit 2b8020f
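
For a sense of scale, a rough back-of-the-envelope sketch of what duplicating the lm_head per Medusa head costs. The dimensions below are assumptions for illustration, not values read from this repo:

# Rough VRAM estimate for per-head lm_head copies (illustrative assumptions only)
vocab_size = 256_000       # assumed Gemma-sized multilingual vocabulary
hidden_size = 4096         # assumed hidden dimension
num_heads = 4              # default medusa_num_heads
bytes_per_param = 2        # fp16 / bf16

extra = vocab_size * hidden_size * bytes_per_param * num_heads
print(f"{extra / 1e9:.1f} GB of duplicated lm_head weights")  # ~8.4 GB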

3 files changed: +27 -23 lines changed


create_data.py

Lines changed: 7 additions & 6 deletions
@@ -11,12 +11,12 @@

 client = httpx.AsyncClient(timeout=None)

-async def run(conv: Conversation):
+async def run(conv: Conversation, url: str):
     payload = {"model":"tgi", "messages": conv.messages}
     response = await client.post(url, json=payload)
     content = response.json()
     message = content["choices"][0]["message"]
-    message.pop("name")
+    message.pop("name", None)
     conv.add_message(message)


@@ -34,15 +34,16 @@ def fix_source(source):
     return new_source


-async def recreate_conversation(conversation, sem):
+async def recreate_conversation(conversation, sem, url):
     async with sem:
         conv = Conversation()
         try:
             for message in conversation[::2]:
                 assert message["role"] == "user"
                 conv.add_message(message)
-                await run(conv)
-        except Exception:
+                await run(conv, url)
+        except Exception as e:
+            print(e)
             pass
         return conv.messages

@@ -62,7 +63,7 @@ async def _main():

     futures = []
     for conversation in conversations:
-        future = recreate_conversation(conversation, sem)
+        future = recreate_conversation(conversation, sem, url)
         futures.append(future)

     recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures)
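
One small but load-bearing change above: message.pop("name", None) supplies a default, so an API response without a "name" field no longer raises. A minimal illustration:

# dict.pop with a default never raises; without one it raises KeyError if the key is absent
message = {"role": "assistant", "content": "hi"}  # no "name" field
message.pop("name", None)   # returns None, message unchanged
# message.pop("name")       # would raise KeyError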

medusa/model/medusa_model_legacy.py

Lines changed: 5 additions & 9 deletions
@@ -25,12 +25,14 @@ def __init__(
         self,
         medusa_num_heads=4,
         medusa_num_layers=1,
+        version="2",
         base_model_name_or_path="lmsys/vicuna-7b-v1.3",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.medusa_num_heads = medusa_num_heads
         self.medusa_num_layers = medusa_num_layers
+        self.version = version
         self.base_model_name_or_path = base_model_name_or_path


@@ -101,7 +103,6 @@ def __init__(
             [
                 nn.Sequential(
                     *([ResBlock(self.hidden_size)] * medusa_num_layers),
-                    nn.Linear(self.hidden_size, self.vocab_size, bias=False),
                 )
                 for _ in range(medusa_num_heads)
             ]
@@ -110,13 +111,6 @@
         # Ensure medusa_head's dtype and device align with the base_model
         self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)

-        import deepspeed
-        params = [base_model.lm_head.weight]
-        with deepspeed.zero.GatheredParameters(params):
-            for i in range(medusa_num_heads):
-                # Initialize the weights of each medusa_head using the base model's weights
-                self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
-
     def get_tokenizer(self):
         """Get the tokenizer of the base model.

@@ -207,7 +201,9 @@ def forward(
         medusa_logits = []
         # TODO: Consider parallelizing this loop for efficiency?
         for i in range(self.medusa):
-            medusa_logits.append(self.medusa_head[i](hidden_states))
+            mhidden_states = self.medusa_head[i](hidden_states)
+            mlogits = self.base_model.lm_head(mhidden_states)
+            medusa_logits.append(mlogits)
         if output_orig:
             return torch.stack(medusa_logits, dim=0), outputs, orig
         return torch.stack(medusa_logits, dim=0)
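
Taken together, the per-head computation after this change looks roughly like the following. This is a sketch, not the class from the repo: the ResBlock definition, the sizes, and the standalone lm_head below are stand-ins for the real base-model objects.

import torch
import torch.nn as nn

class ResBlock(nn.Module):
    # Stand-in residual block; the repo's own ResBlock is assumed to be similar
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

hidden_size, vocab_size, num_heads, num_layers = 4096, 32000, 4, 1

# Shared projection to the vocabulary: in the real model this is base_model.lm_head
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Medusa-2 style heads: residual MLP blocks only, no per-head vocab projection
medusa_head = nn.ModuleList(
    [nn.Sequential(*[ResBlock(hidden_size) for _ in range(num_layers)]) for _ in range(num_heads)]
)

hidden_states = torch.randn(1, 8, hidden_size)
# Each head refines the hidden states, then reuses the single shared lm_head
medusa_logits = torch.stack([lm_head(head(hidden_states)) for head in medusa_head], dim=0)
print(medusa_logits.shape)  # torch.Size([4, 1, 8, 32000])

The vocab-sized weight matrix now exists once instead of once per head, which is where the VRAM saving comes from.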

medusa/train/train_legacy.py

Lines changed: 15 additions & 8 deletions
@@ -335,6 +335,20 @@ def train():
         config.rope_scaling = {"type": "linear", "factor": scaling_factor}
     config.use_cache = False

+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+        model_max_length=training_args.model_max_length,
+        padding_side="right",
+        use_fast=True,
+    )
+    tokenizer.pad_token = tokenizer.unk_token
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # Making sure the tokenizer works before loading the model.
+    print(tokenizer(["This is a test", "secondary"], padding=True))
+    print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))
+
     # Load model and tokenizer
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
@@ -358,14 +372,6 @@ def train():
     # Format output dir
     training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"

-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=training_args.cache_dir,
-        model_max_length=training_args.model_max_length,
-        padding_side="right",
-        use_fast=True,
-    )
-    tokenizer.pad_token = tokenizer.unk_token

     # Load data
     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
@@ -375,6 +381,7 @@ def train():
         medusa_num_heads=training_args.medusa_num_heads,
         medusa_num_layers=training_args.medusa_num_layers,
         base_model_name_or_path=model_args.model_name_or_path,
+        version="2"
     )

     # Save Medusa config
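
The version="2" field exists so loaders can tell whether head outputs are already logits (legacy heads) or hidden states that still need the shared lm_head (heads trained after this commit). A hypothetical consumer-side check, not code from this commit:

# Hypothetical downstream logic (illustrative only; names are placeholders)
head_out = medusa_head[i](hidden_states)
if getattr(medusa_config, "version", "1") == "2":
    logits = base_model.lm_head(head_out)  # v2 heads emit hidden states
else:
    logits = head_out                      # legacy heads already emit logits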
