
Commit 5e98053

Merge pull request #83 from Narsil/recipe_for_other_models
Adding recipe for other models (non llama, non vicuna).
2 parents 700ff84 + 64f4924 commit 5e98053

File tree: 5 files changed (+181, -89 lines)


README.md

Lines changed: 28 additions & 5 deletions
@@ -125,7 +125,7 @@ accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml
 
 The data preparation code for self-distillation can be found in the [`data_generation` folder](data_generation) of the current repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo.
 
-### Training (legacy)
+### Training on various architectures
 *The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.*
 
 For training, please install:
@@ -141,14 +141,36 @@ Remark: If you haven't installed `git-lfs`, please install it before cloning:
 ```bash
 git lfs install
 ```
+
+#### Adapt the data to the model you want to enable Medusa on
+
+Start by launching an inference server of your choice that runs the model you want to train on.
+Let's use [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as an example.
+
+For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you
+can also use after you've trained the Medusa heads.
+
+```
+model=mistralai/Mistral-7B-Instruct-v0.2
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
+```
+Some of the ShareGPT sequences are relatively long, so make sure you can run inference on them. If there is not enough room, the script will simply skip those long conversations.
+This shouldn't hurt downstream performance too much, but more data is always better.
+You can make various tradeoffs to [speed up inference](https://huggingface.co/docs/text-generation-inference/index), but the defaults should be good enough in most cases.
+
+```
+python create_data.py --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json --output-filename mistral.json
+```
+
 #### Train the model
 We follow the training setup from [FastChat](https://github.com/lm-sys/FastChat#fine-tuning), but with a much larger learning rate because we freeze the original model and only train the new heads. Here is the training command for a 7B model on 4 GPUs. Since we are only training the new heads, the training does not require a lot of memory, and only data parallelism is needed. You can modify the script to fit your own setup. For larger models, we use the same setup. You can also use `--load_in_8bit` or `--load_in_4bit` to load the base model in quantized format.
 ```bash
-torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
-    --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
+torchrun --nproc_per_node=4 medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
+    --data_path mistral.json \
     --bf16 True \
     --output_dir test \
-    --num_train_epochs 1 \
+    --num_train_epochs 2 \
     --per_device_train_batch_size 8 \
     --per_device_eval_batch_size 8 \
     --gradient_accumulation_steps 4 \
@@ -163,7 +185,8 @@ torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vic
     --model_max_length 2048 \
     --lazy_preprocess True \
     --medusa_num_heads 3 \
-    --medusa_num_layers 1
+    --medusa_num_layers 1 \
+    --deepspeed deepspeed.json
 ```
 ### Push to Hugging Face Hub
 You can use the following command to push your model to the Hugging Face Hub:
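For reference, the `mistral.json` produced by the `create_data.py` command above is a JSON list of conversations, each a list of `{"role", "content"}` messages (see `create_data.py` below). Here is a minimal sketch, not part of the commit, for inspecting the output and counting conversations that came back without a usable exchange:

```python
# Sketch only: assumes mistral.json was produced by the create_data.py command above.
import json

with open("mistral.json") as f:
    conversations = json.load(f)

# Conversations that failed on the server are kept, but may be empty or truncated.
usable = [c for c in conversations if len(c) >= 2]
print(f"{len(usable)}/{len(conversations)} conversations contain at least one user/assistant exchange")
```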

create_data.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import asyncio
import json

import httpx
import tqdm.asyncio
import typer
from transformers import Conversation
from typing_extensions import Annotated

app = typer.Typer()

client = httpx.AsyncClient(timeout=None)


async def run(conv: Conversation, url: str):
    # Ask the served model (behind an OpenAI-compatible endpoint) to produce
    # the next assistant turn for the conversation so far.
    payload = {"model": "tgi", "messages": conv.messages}
    response = await client.post(url, json=payload)
    content = response.json()
    message = content["choices"][0]["message"]
    message.pop("name", None)
    conv.add_message(message)


def fix_source(source):
    # Convert ShareGPT-style turns ({"from", "value"}) into chat messages
    # ({"role", "content"}).
    if source and source[0]["from"] == "gpt":
        # Skip if GPT is first to talk
        source = source[1:]
    new_source = []
    for item in source:
        role = "assistant" if item["from"] == "gpt" else "user"
        content = item["value"]
        new_source.append({"role": role, "content": content})
    return new_source


async def recreate_conversation(conversation, sem, url):
    # Replay the user turns and let the served model regenerate every assistant
    # turn; conversations that fail (e.g. too long) are returned as-is.
    async with sem:
        conv = Conversation()
        try:
            for message in conversation[::2]:
                assert message["role"] == "user"
                conv.add_message(message)
                await run(conv, url)
        except Exception:
            pass
    return conv.messages


@app.command()
def main(
    *,
    input_filename: Annotated[str, typer.Option("--input-filename")],
    output_filename: Annotated[str, typer.Option("--output-filename")],
    url: Annotated[str, typer.Option("--url")] = "http://localhost:8080/v1/chat/completions",
    concurrency: Annotated[int, typer.Option("--concurrency")] = 64,
):
    sem = asyncio.Semaphore(concurrency)

    async def _main():
        with open(input_filename, "r") as f:
            input_data = json.loads(f.read())
        conversations = [fix_source(source["conversations"]) for source in input_data]

        futures = []
        for conversation in conversations:
            future = recreate_conversation(conversation, sem, url)
            futures.append(future)

        recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures)

        with open(output_filename, "w") as f:
            json.dump(recreated_conversations, f, indent=4)

    asyncio.run(_main())


if __name__ == "__main__":
    app()
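Before launching the full regeneration, it can be worth checking that the server answers the same kind of request the script sends. A minimal sketch, assuming the TGI container from the README is listening on localhost:8080:

```python
# Sketch only: one synchronous request with the same payload shape as run() above.
import httpx

url = "http://localhost:8080/v1/chat/completions"
payload = {"model": "tgi", "messages": [{"role": "user", "content": "Hello!"}]}

response = httpx.post(url, json=payload, timeout=60)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```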

deepspeed.json

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
{
  "bf16": {
    "enabled": "auto"
  },

  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },

  "gradient_accumulation_steps": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
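This config is consumed by the Hugging Face Trainer through the `--deepspeed deepspeed.json` flag added to the training command in the README; the `"auto"` entries are filled in from the training arguments at launch time. A rough sketch of the equivalent programmatic wiring, assuming `transformers` and `deepspeed` are installed:

```python
# Sketch only: mirrors the flags used in the README's torchrun command.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="test",
    bf16=True,                      # resolves "bf16": {"enabled": "auto"}
    per_device_train_batch_size=8,  # resolves "train_micro_batch_size_per_gpu": "auto"
    gradient_accumulation_steps=4,  # resolves "gradient_accumulation_steps": "auto"
    deepspeed="deepspeed.json",     # same file as --deepspeed on the command line
)
```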

medusa/model/medusa_model_legacy.py

Lines changed: 9 additions & 6 deletions
@@ -90,8 +90,8 @@ def __init__(
         super().__init__()
         self.base_model = base_model
         self.config = base_model.config
-        self.hidden_size = base_model.lm_head.weight.shape[-1]
-        self.vocab_size = base_model.lm_head.weight.shape[0]
+        self.hidden_size = base_model.config.hidden_size
+        self.vocab_size = base_model.config.vocab_size
         self.medusa = medusa_num_heads
         self.medusa_num_layers = medusa_num_layers
         self.base_model_name_or_path = base_model_name_or_path
@@ -110,9 +110,12 @@ def __init__(
         # Ensure medusa_head's dtype and device align with the base_model
         self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)
 
-        for i in range(medusa_num_heads):
-            # Initialize the weights of each medusa_head using the base model's weights
-            self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
+        import deepspeed
+        params = [base_model.lm_head.weight]
+        with deepspeed.zero.GatheredParameters(params):
+            for i in range(medusa_num_heads):
+                # Initialize the weights of each medusa_head using the base model's weights
+                self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]
 
     def get_tokenizer(self):
         """Get the tokenizer of the base model.
@@ -189,7 +192,7 @@ def forward(
             torch.Tensor: A tensor containing predictions from all Medusa heads.
             (Optional) Original predictions from the base model's LM head.
         """
-        with torch.inference_mode():
+        with torch.no_grad():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
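These changes line up with the DeepSpeed ZeRO-3 setup above: under stage 3, `lm_head.weight` is sharded across ranks, so the sizes are taken from the config instead of the weight's shape and the weight is gathered with `deepspeed.zero.GatheredParameters` before being copied into the new heads; `torch.no_grad()` replaces `torch.inference_mode()` because tensors created in inference mode cannot be used in autograd later, which the trainable Medusa heads require. A minimal PyTorch sketch of that last point, using stand-in linear layers rather than the actual model:

```python
# Sketch only: illustrates why no_grad is used instead of inference_mode when
# the frozen base model's hidden states feed modules that still need gradients.
import torch

base = torch.nn.Linear(16, 16)  # stand-in for the frozen base model
head = torch.nn.Linear(16, 16)  # stand-in for a trainable Medusa head
x = torch.randn(2, 16)

with torch.inference_mode():
    hidden = base(x)
# head(hidden).sum().backward()  # would fail: inference tensors cannot be saved for backward

with torch.no_grad():
    hidden = base(x)             # no gradients flow into the base model
head(hidden).sum().backward()    # but the head's parameters do get gradients
```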
