
Commit 0bfdcd2

Adding recipe for other models (non llama, non vicuna).
1 parent 700ff84 commit 0bfdcd2

File tree

4 files changed, +157 −89 lines changed


README.md

Lines changed: 28 additions & 5 deletions
@@ -125,7 +125,7 @@ accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml
 
 The data preparation code for self-distillation can be found in [`data_generation` folder](data_generation) of the current repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo.
 
-### Training (legacy)
+### Training on various architectures
 *The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.*
 
 For training, please install:
@@ -141,14 +141,36 @@ Remark: If you haven't installed `git-lfs`, please install it before cloning:
 ```bash
 git lfs install
 ```
+
+#### Adapt the data to the model you want to enable Medusa on
+
+Start by launching an inference server of your choice that runs the model you want to train on.
+Let's use [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as an example.
+
+For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you
+can also use after you've trained the Medusa heads.
+
+```
+model=mistralai/Mistral-7B-Instruct-v0.2
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
+```
+
+Some of the ShareGPT conversations are relatively long, so make sure the server can handle them. If there is not enough room, the script will simply skip those long conversations.
+This shouldn't hurt downstream performance much, but more data is always better.
+You can use various tradeoffs to [speed up inference](https://huggingface.co/docs/text-generation-inference/index), but the defaults should be good enough in most cases.
+
+```
+python create_data.py --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json --output-filename mistral.json
+```
+
 #### Train the model
 We follow the training setup from [FastChat](https://github.com/lm-sys/FastChat#fine-tuning), but with a much larger learning rate because we freeze the original model and only train the new heads. Here is the training command for the Vicuna-7b model on 4 GPUs. Since we are only training the new heads, the training does not require a lot of memory, and only data parallelism is needed. You can modify the script to fit your own setup. For larger models, we use the same setup. You can also use `--load_in_8bit` or `--load_in_4bit` to load the base model in quantized format.
 ```bash
-torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
-    --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
+torchrun --nproc_per_node=4 medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
+    --data_path mistral.json \
     --bf16 True \
     --output_dir test \
-    --num_train_epochs 1 \
+    --num_train_epochs 2 \
     --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
     --gradient_accumulation_steps 4 \
@@ -163,7 +185,8 @@ torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vic
     --model_max_length 2048 \
     --lazy_preprocess True \
     --medusa_num_heads 3 \
-    --medusa_num_layers 1
+    --medusa_num_layers 1 \
+    --deepspeed deepspeed.json
 ```
 ### Push to Hugging Face Hub
 You can use the following command to push your model to the Hugging Face Hub:
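
The training command above now passes `--deepspeed deepspeed.json`, and the model and save code in this commit use `deepspeed.zero.GatheredParameters`, which only matters when parameters are partitioned under ZeRO stage 3. The actual `deepspeed.json` referenced by the command is not shown in this diff; purely as a hypothetical stand-in, a minimal ZeRO-3 config that defers batch-size and precision settings to the Hugging Face Trainer could be generated like this:

```python
import json

# Hypothetical minimal ZeRO-3 config; the repo's actual deepspeed.json may differ.
# "auto" values are filled in by the Hugging Face Trainer from its own arguments.
ds_config = {
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

with open("deepspeed.json", "w") as f:
    json.dump(ds_config, f, indent=4)
```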

create_data.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import typer
+import json
+from transformers import Conversation
+from typing_extensions import Annotated
+import httpx
+import tqdm.asyncio
+import asyncio
+
+app = typer.Typer()
+
+
+client = httpx.AsyncClient(timeout=None)
+
+
+async def run(conv: Conversation, url: str):
+    # Ask the inference server for the next assistant message given the conversation so far.
+    payload = {"model": "tgi", "messages": conv.messages}
+    response = await client.post(url, json=payload)
+    content = response.json()
+    message = content["choices"][0]["message"]
+    message.pop("name", None)
+    conv.add_message(message)
+
+
+def fix_source(source):
+    if source and source[0]["from"] == "gpt":
+        # Skip if GPT is first to talk
+        source = source[1:]
+    new_source = []
+    for item in source:
+        role = "assistant" if item["from"] == "gpt" else "user"
+        content = item["value"]
+        new_source.append({"role": role, "content": content})
+    return new_source
+
+
+async def recreate_conversation(conversation, sem, url):
+    async with sem:
+        conv = Conversation()
+        try:
+            # Keep the user turns and regenerate every assistant turn with the target model.
+            for message in conversation[::2]:
+                assert message["role"] == "user"
+                conv.add_message(message)
+                await run(conv, url)
+        except Exception:
+            # If the server rejects a conversation (e.g. it is too long), keep what was recreated so far.
+            pass
+        return conv.messages
+
+
+@app.command()
+def main(
+    *,
+    input_filename: Annotated[str, typer.Option("--input-filename")],
+    output_filename: Annotated[str, typer.Option("--output-filename")],
+    url: Annotated[str, typer.Option("--url")] = "http://localhost:8080/v1/chat/completions",
+    concurrency: Annotated[int, typer.Option("--concurrency")] = 64,
+):
+    sem = asyncio.Semaphore(concurrency)
+
+    async def _main():
+        with open(input_filename, "r") as f:
+            input_data = json.loads(f.read())
+        conversations = [fix_source(source["conversations"]) for source in input_data]
+
+        futures = []
+        for conversation in conversations:
+            future = recreate_conversation(conversation, sem, url)
+            futures.append(future)
+
+        recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures)
+
+        with open(output_filename, "w") as f:
+            json.dump(recreated_conversations, f, indent=4)
+
+    asyncio.run(_main())
+
+
+if __name__ == "__main__":
+    app()
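
Before converting the full ShareGPT dump, it can be worth checking that the server started in the README section above answers on the route this script posts to. A minimal sketch, assuming the default `--url` above (`http://localhost:8080/v1/chat/completions`) and a text-generation-inference version that exposes the OpenAI-compatible Messages API:

```python
import httpx

# Same endpoint and payload shape as create_data.py's run(); "tgi" is a placeholder
# model name, since the server only serves the single model it was launched with.
url = "http://localhost:8080/v1/chat/completions"
payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "Say hello in one short sentence."}],
}

response = httpx.post(url, json=payload, timeout=60)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```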

medusa/model/medusa_model_legacy.py

Lines changed: 9 additions & 6 deletions
@@ -90,8 +90,8 @@ def __init__(
         super().__init__()
         self.base_model = base_model
         self.config = base_model.config
-        self.hidden_size = base_model.lm_head.weight.shape[-1]
-        self.vocab_size = base_model.lm_head.weight.shape[0]
+        self.hidden_size = base_model.config.hidden_size
+        self.vocab_size = base_model.config.vocab_size
         self.medusa = medusa_num_heads
         self.medusa_num_layers = medusa_num_layers
         self.base_model_name_or_path = base_model_name_or_path
@@ -110,9 +110,12 @@ def __init__(
         # Ensure medusa_head's dtype and device align with the base_model
         self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)
 
-        for i in range(medusa_num_heads):
-            # Initialize the weights of each medusa_head using the base model's weights
-            self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
+        import deepspeed
+        params = [base_model.lm_head.weight]
+        with deepspeed.zero.GatheredParameters(params):
+            for i in range(medusa_num_heads):
+                # Initialize the weights of each medusa_head using the base model's weights
+                self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
 
     def get_tokenizer(self):
         """Get the tokenizer of the base model.
"""Get the tokenizer of the base model.
@@ -189,7 +192,7 @@ def forward(
             torch.Tensor: A tensor containing predictions from all Medusa heads.
             (Optional) Original predictions from the base model's LM head.
         """
-        with torch.inference_mode():
+        with torch.no_grad():
             # Pass input through the base model
             outputs = self.base_model.model(
                 input_ids=input_ids,
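
The head-initialization change above wraps the `lm_head` copy in `deepspeed.zero.GatheredParameters`: under ZeRO stage 3 the base model's weights may only exist as shards on each rank, so the full tensor has to be materialized before it can be copied into the new heads. A standalone sketch of the same pattern on toy layers (hypothetical sizes, no distributed setup, assuming `deepspeed` is installed; outside ZeRO-3 the context manager is effectively a no-op, so the copy behaves like a plain tensor copy):

```python
import torch
import deepspeed

# Toy stand-ins for the base model's lm_head and one Medusa head's output layer.
lm_head = torch.nn.Linear(4096, 32000, bias=False)
medusa_out = torch.nn.Linear(4096, 32000, bias=False)

# Under ZeRO-3, lm_head.weight may only hold a shard on this rank; gathering it
# gives every rank the full tensor for the duration of the block so the copy is valid.
with deepspeed.zero.GatheredParameters([lm_head.weight]):
    medusa_out.weight.data[:] = lm_head.weight.data[:]
```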

medusa/train/train_legacy.py

Lines changed: 44 additions & 78 deletions
@@ -29,6 +29,7 @@
 import transformers
 from transformers import Trainer, BitsAndBytesConfig
 from transformers.trainer_pt_utils import LabelSmoother
+from safetensors.torch import save_file
 
 from fastchat.conversation import SeparatorStyle
 from fastchat.model.model_adapter import get_conversation_template
@@ -80,7 +81,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             medusa_labels = medusa_labels[not_ignore]
 
             # Add top-k accuracy
-            for k in range(1, 6):
+            for k in range(1, 2):
                 _, topk = medusa_logits.topk(k, dim=-1)
                 topk = topk[not_ignore]
                 correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
@@ -119,6 +120,7 @@ class DataArguments:
 @dataclass
 class TrainingArguments(transformers.TrainingArguments):
     cache_dir: Optional[str] = field(default=None)
+    report_to: Optional[str] = None
     optim: str = field(default="adamw_torch")
     model_max_length: int = field(
         default=2048,
@@ -158,7 +160,6 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: st
     del state_dict
     trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
 
-
 def preprocess(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -173,73 +174,43 @@ def preprocess(
     Returns:
         Dict: A dictionary containing tokenized inputs, labels, and attention mask.
     """
-    conv = get_conversation_template("vicuna")
-    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
     # Apply prompt templates
     conversations = []
-    for i, source in enumerate(sources):
-        if roles[source[0]["from"]] != conv.roles[0]:
-            # Skip the first one if it is not from human
-            source = source[1:]
-
-        conv.messages = []
-        for j, sentence in enumerate(source):
-            role = roles[sentence["from"]]
-            assert role == conv.roles[j % 2], f"{i}, {j}, {role}, {conv.roles[j % 2]}"
-            conv.append_message(role, sentence["value"])
-        conversations.append(conv.get_prompt())
+    prompts = []
+    for i, conversation in enumerate(sources):
+        prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
+        prompts.append(prompt)
+        conversations.append(conversation)
 
     # Tokenize conversations
-    input_ids = tokenizer(
-        conversations,
+    encoding = tokenizer(
+        prompts,
         return_tensors="pt",
         padding="max_length",
-        max_length=tokenizer.model_max_length,
         truncation=True,
-    ).input_ids
-    targets = input_ids.clone()
-
-    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
+        return_offsets_mapping=True,
+    )
+    # Set everything to be ignored, except the assistant part
+    targets = torch.full_like(encoding.input_ids, IGNORE_TOKEN_ID)
+    input_ids = encoding.input_ids
 
     # Mask targets. Only compute loss on the assistant outputs.
-    sep = conv.sep + conv.roles[1] + ": "
-    for conversation, target in zip(conversations, targets):
-        total_len = int(target.ne(tokenizer.pad_token_id).sum())
-
-        turns = conversation.split(conv.sep2)
-        cur_len = 1
-        target[:cur_len] = IGNORE_TOKEN_ID
-        for i, turn in enumerate(turns):
-            if turn == "":
-                break
-            turn_len = len(tokenizer(turn).input_ids)
-
-            parts = turn.split(sep)
-            if len(parts) != 2:
-                break
-            parts[0] += sep
-            # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
-            instruction_len = len(tokenizer(parts[0]).input_ids) - 2
-
-            # Ignore the user instructions
-            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
-            cur_len += turn_len
-
-        target[cur_len:] = IGNORE_TOKEN_ID
-
-        if False:  # Inspect and check the correctness of masking
-            z = target.clone()
-            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
-            rank0_print(tokenizer.decode(z))
-
-        if cur_len < tokenizer.model_max_length:
-            if cur_len != total_len:
-                target[:] = IGNORE_TOKEN_ID
-                rank0_print(
-                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
-                    f" (ignored)"
-                )
+    for conv_index, (conversation, target, prompt) in enumerate(zip(conversations, targets, prompts)):
+        for turn in conversation:
+            if turn["role"] == "assistant":
+                content = turn["content"]
+                # Unfortunate strip() necessary because chat templates are doing the same.
+                start = prompt.index(content.strip())
+                stop = start + len(content.strip())
+                indices = []
+                for tok_index, (tok_start, tok_stop) in enumerate(encoding.offset_mapping[conv_index]):
+                    # Keep tokens whose character span overlaps the assistant reply.
+                    if tok_start < stop and tok_stop > start:
+                        indices.append(tok_index)
+                target[indices] = encoding.input_ids[conv_index][indices]
 
     return dict(
         input_ids=input_ids,
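
The rewritten `preprocess` above no longer hand-parses the Vicuna separator format; it renders the chat template, tokenizes with `return_offsets_mapping=True`, and then un-ignores exactly the tokens whose character spans overlap an assistant reply. A self-contained sketch of that masking step on hypothetical offsets and token ids (no tokenizer download needed):

```python
import torch

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index, as in train_legacy.py

# Hypothetical offsets for 8 tokens of a rendered prompt; the final (0, 0) entry
# mimics a padding token. [start, stop) is the character span of one assistant reply.
offset_mapping = [(0, 4), (4, 9), (9, 14), (14, 20), (20, 26), (26, 31), (31, 37), (0, 0)]
input_ids = torch.tensor([101, 17, 923, 4, 88, 250, 7, 0])
start, stop = 14, 31

labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
indices = [
    i
    for i, (tok_start, tok_stop) in enumerate(offset_mapping)
    if tok_start < stop and tok_stop > start  # token overlaps the assistant span
]
labels[indices] = input_ids[indices]
print(labels)  # tensor([-100, -100, -100,    4,   88,  250, -100, -100])
```
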
@@ -260,7 +231,7 @@ def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
         super(SupervisedDataset, self).__init__()
 
         rank0_print("Formatting inputs...")
-        sources = [example["conversations"] for example in raw_data]
+        sources = raw_data
         data_dict = preprocess(sources, tokenizer)
 
         self.input_ids = data_dict["input_ids"]
@@ -304,7 +275,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         if i in self.cached_data_dict:
             return self.cached_data_dict[i]
 
-        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
+        ret = preprocess([self.raw_data[i]], self.tokenizer)
         ret = dict(
             input_ids=ret["input_ids"][0],
             labels=ret["labels"][0],
@@ -364,23 +335,12 @@ def train():
         config.rope_scaling = {"type": "linear", "factor": scaling_factor}
     config.use_cache = False
 
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4",
-    )
-
     # Load model and tokenizer
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         cache_dir=training_args.cache_dir,
-        low_cpu_mem_usage=True,
         torch_dtype=torch.bfloat16,
-        quantization_config=quantization_config if model_args.load_in_4bit else None,
-        load_in_4bit=model_args.load_in_4bit,
-        load_in_8bit=model_args.load_in_8bit,
     )
 
     # Freeze the base model
@@ -403,7 +363,7 @@ def train():
         cache_dir=training_args.cache_dir,
         model_max_length=training_args.model_max_length,
         padding_side="right",
-        use_fast=False,
+        use_fast=True,
     )
     tokenizer.pad_token = tokenizer.unk_token
 
@@ -420,7 +380,6 @@ def train():
     # Save Medusa config
     medusa_config.save_pretrained(training_args.output_dir)
 
-    # import pdb; pdb.set_trace()
     # Start trainer
     trainer = CustomizedTrainer(
         model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
@@ -438,12 +397,19 @@ def train():
         lm_head = medusa_lm_head.module.medusa_head
     else:
         lm_head = medusa_lm_head.medusa_head
+    import deepspeed
+    with deepspeed.zero.GatheredParameters(lm_head.parameters()):
+        state_dict = lm_head.state_dict()
 
     # Save Medusa heads
-    torch.save(
-        lm_head.state_dict(),
-        os.path.join(training_args.output_dir, "medusa_lm_head.pt"),
-    )
+    if local_rank == 0:
+        # Modify the tokenizer internal state before saving.
+        tokenizer.encode("Test", truncation=None, padding="do_not_pad")
+        tokenizer.save_pretrained(training_args.output_dir)
+        save_file(
+            state_dict,
+            os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"),
+        )
 
 
 if __name__ == "__main__":
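
With this change the Medusa heads are written with `safetensors.torch.save_file` (alongside the tokenizer) instead of `torch.save`, so downstream code that still expects `medusa_lm_head.pt` needs to switch loaders. A minimal sketch of reading the new artifact back, assuming the `test` output directory from the README command above:

```python
import os
from safetensors.torch import load_file

output_dir = "test"  # matches --output_dir in the README example
state_dict = load_file(os.path.join(output_dir, "medusa_lm_head.safetensors"))

# Inspect the saved head weights before loading them into a MedusaModel.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape), tensor.dtype)
```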
