
Commit e2a5d20

Merge pull request #97 from Narsil/medusa2
Creating medusa2.
2 parents: 5e98053 + 2b8020f

3 files changed: +27 -23 lines


create_data.py

Lines changed: 7 additions & 6 deletions
@@ -11,12 +11,12 @@

 client = httpx.AsyncClient(timeout=None)

-async def run(conv: Conversation):
+async def run(conv: Conversation, url: str):
     payload = {"model":"tgi", "messages": conv.messages}
     response = await client.post(url, json=payload)
     content = response.json()
     message = content["choices"][0]["message"]
-    message.pop("name")
+    message.pop("name", None)
     conv.add_message(message)
@@ -34,15 +34,16 @@ def fix_source(source):
     return new_source


-async def recreate_conversation(conversation, sem):
+async def recreate_conversation(conversation, sem, url):
     async with sem:
         conv = Conversation()
         try:
             for message in conversation[::2]:
                 assert message["role"] == "user"
                 conv.add_message(message)
-                await run(conv)
-        except Exception:
+                await run(conv, url)
+        except Exception as e:
+            print(e)
             pass
         return conv.messages
@@ -62,7 +63,7 @@ async def _main():

     futures = []
     for conversation in conversations:
-        future = recreate_conversation(conversation, sem)
+        future = recreate_conversation(conversation, sem, url)
         futures.append(future)

     recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures)
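
Read together, the three hunks thread the serving endpoint url explicitly through _main -> recreate_conversation -> run, make message.pop("name", None) tolerant of backends that omit the field, and print per-conversation failures instead of swallowing them silently. A minimal sketch of the resulting flow follows; Conversation is a stand-in for the script's class, the endpoint URL and semaphore size are assumptions (any OpenAI-compatible chat server works), and plain asyncio.gather replaces the script's tqdm progress bar.

# Sketch only: Conversation, the URL, and the semaphore size are assumptions.
import asyncio

import httpx

client = httpx.AsyncClient(timeout=None)


class Conversation:
    def __init__(self):
        self.messages = []

    def add_message(self, message):
        self.messages.append(message)


async def run(conv, url):
    payload = {"model": "tgi", "messages": conv.messages}
    response = await client.post(url, json=payload)
    message = response.json()["choices"][0]["message"]
    message.pop("name", None)  # tolerate responses without a "name" field
    conv.add_message(message)


async def recreate_conversation(conversation, sem, url):
    async with sem:
        conv = Conversation()
        try:
            for message in conversation[::2]:  # every other turn: the user side
                conv.add_message(message)
                await run(conv, url)
        except Exception as e:
            print(e)  # log the failure, keep what was recreated so far
        return conv.messages


async def _main():
    url = "http://localhost:8080/v1/chat/completions"  # assumed endpoint
    sem = asyncio.Semaphore(16)  # cap concurrent requests
    conversations = [[{"role": "user", "content": "Hello"}]]
    futures = [recreate_conversation(c, sem, url) for c in conversations]
    print(await asyncio.gather(*futures))
    await client.aclose()


asyncio.run(_main())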

medusa/model/medusa_model_legacy.py

Lines changed: 5 additions & 9 deletions
@@ -25,12 +25,14 @@ def __init__(
         self,
         medusa_num_heads=4,
         medusa_num_layers=1,
+        version="2",
         base_model_name_or_path="lmsys/vicuna-7b-v1.3",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.medusa_num_heads = medusa_num_heads
         self.medusa_num_layers = medusa_num_layers
+        self.version = version
         self.base_model_name_or_path = base_model_name_or_path
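The new version field (default "2") travels with saved checkpoints, letting loaders distinguish heads that carry their own vocabulary projection (version 1) from heads that reuse the base model's lm_head (version 2, the hunks below). A minimal round-trip sketch, assuming the config class extends transformers.PretrainedConfig so instance attributes are serialized into config.json:

# Assumes MedusaConfig extends PretrainedConfig, as the super().__init__(**kwargs)
# call above suggests; the directory name is arbitrary.
from transformers import PretrainedConfig


class MedusaConfig(PretrainedConfig):
    def __init__(
        self,
        medusa_num_heads=4,
        medusa_num_layers=1,
        version="2",
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.medusa_num_heads = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.version = version
        self.base_model_name_or_path = base_model_name_or_path


cfg = MedusaConfig()
cfg.save_pretrained("medusa_ckpt")  # writes medusa_ckpt/config.json
print(MedusaConfig.from_pretrained("medusa_ckpt").version)  # -> "2"
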
@@ -101,7 +103,6 @@ def __init__(
             [
                 nn.Sequential(
                     *([ResBlock(self.hidden_size)] * medusa_num_layers),
-                    nn.Linear(self.hidden_size, self.vocab_size, bias=False),
                 )
                 for _ in range(medusa_num_heads)
             ]
@@ -110,13 +111,6 @@ def __init__(
         # Ensure medusa_head's dtype and device align with the base_model
         self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)

-        import deepspeed
-        params = [base_model.lm_head.weight]
-        with deepspeed.zero.GatheredParameters(params):
-            for i in range(medusa_num_heads):
-                # Initialize the weights of each medusa_head using the base model's weights
-                self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
-
     def get_tokenizer(self):
         """Get the tokenizer of the base model.
@@ -207,7 +201,9 @@ def forward(
         medusa_logits = []
         # TODO: Consider parallelizing this loop for efficiency?
         for i in range(self.medusa):
-            medusa_logits.append(self.medusa_head[i](hidden_states))
+            mhidden_states = self.medusa_head[i](hidden_states)
+            mlogits = self.base_model.lm_head(mhidden_states)
+            medusa_logits.append(mlogits)
         if output_orig:
             return torch.stack(medusa_logits, dim=0), outputs, orig
         return torch.stack(medusa_logits, dim=0)
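
These three hunks are the substance of the medusa2 change. Each head previously ended in its own nn.Linear(hidden_size, vocab_size), copy-initialized from base_model.lm_head under a DeepSpeed ZeRO gather, which duplicated a vocabulary-sized matrix per head. The heads now stop at hidden size and the base model's lm_head performs the final projection for all of them, so the copy-initialization (and the deepspeed import) has nothing left to do. A runnable shape-flow sketch with toy sizes; this ResBlock is a plausible stand-in, not necessarily the repo's exact block:

# Toy sizes; ResBlock and lm_head are stand-ins for the repo's modules.
import torch
import torch.nn as nn

hidden_size, vocab_size = 64, 1000
medusa_num_heads, medusa_num_layers = 4, 1


class ResBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))


lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # stands in for base_model.lm_head
medusa_head = nn.ModuleList(
    [
        nn.Sequential(*([ResBlock(hidden_size)] * medusa_num_layers))
        for _ in range(medusa_num_heads)
    ]
)

hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)
medusa_logits = []
for i in range(medusa_num_heads):
    mhidden_states = medusa_head[i](hidden_states)  # still hidden-sized
    medusa_logits.append(lm_head(mhidden_states))   # shared vocab projection
print(torch.stack(medusa_logits, dim=0).shape)  # torch.Size([4, 2, 5, 1000])

At vicuna-7B scale (hidden size 4096, vocabulary 32000), dropping the per-head Linear saves roughly 4096 x 32000, about 131M parameters, per head.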

medusa/train/train_legacy.py

Lines changed: 15 additions & 8 deletions
@@ -335,6 +335,20 @@ def train():
         config.rope_scaling = {"type": "linear", "factor": scaling_factor}
     config.use_cache = False

+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+        model_max_length=training_args.model_max_length,
+        padding_side="right",
+        use_fast=True,
+    )
+    tokenizer.pad_token = tokenizer.unk_token
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # Making sure the tokenizer works before loading the model.
+    print(tokenizer(["This is a test", "secondary"], padding=True))
+    print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))
+
     # Load model and tokenizer
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
@@ -358,14 +372,6 @@ def train():
     # Format output dir
     training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"

-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=training_args.cache_dir,
-        model_max_length=training_args.model_max_length,
-        padding_side="right",
-        use_fast=True,
-    )
-    tokenizer.pad_token = tokenizer.unk_token

     # Load data
     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
@@ -375,6 +381,7 @@ def train():
         medusa_num_heads=training_args.medusa_num_heads,
         medusa_num_layers=training_args.medusa_num_layers,
         base_model_name_or_path=model_args.model_name_or_path,
+        version="2"
     )

     # Save Medusa config
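
Moving tokenizer construction ahead of the expensive AutoModelForCausalLM.from_pretrained call makes tokenizer problems (a missing pad token, a broken chat template) fail in seconds instead of after a long model load; that is what the two print calls exercise. Note the diff keeps both pad_token assignments, so the later eos_token one is the effective choice. A minimal sketch of the fail-fast ordering, assuming a checkpoint that ships a chat template:

# Sketch of the fail-fast ordering; the model name is an example and the
# checkpoint is assumed to provide a chat template for apply_chat_template.
import transformers

model_name = "lmsys/vicuna-7b-v1.3"

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=2048,
    padding_side="right",
    use_fast=True,
)
tokenizer.pad_token = tokenizer.eos_token  # the later of the diff's two assignments

# Exercise padding and the chat template before paying for the model load.
print(tokenizer(["This is a test", "secondary"], padding=True))
print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))

model = transformers.AutoModelForCausalLM.from_pretrained(model_name)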
