The data preparation code for self-distillation can be found in the [`data_generation`](data_generation) folder of this repo. For other datasets, you can download the data directly from the corresponding Hugging Face dataset repo.

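For example, a Hugging Face dataset repo can be cloned with git like any other repository (the dataset id below is only a placeholder, substitute the dataset you actually want to train on):

```bash
# Placeholder dataset id; replace <org>/<dataset> with the Hugging Face dataset repo you need.
git clone https://huggingface.co/datasets/<org>/<dataset>
```
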
### Training on various architectures

*The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.*

For training, please install the repo's training dependencies.

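Assuming the repo defines a `train` extra in its packaging metadata (the extras name is an assumption, adjust it to what the repo actually ships), an editable install would look like:

```bash
# Assumed extras name "train"; install the repo in editable mode with its training dependencies.
pip install -e ".[train]"
```
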
Remark: If you haven't installed `git-lfs`, please install it before cloning the training dataset:

```bash
git lfs install
```

#### Adapt the data to the model you want to enable Medusa on

Start by launching an inference server of your choice that will run the model you want to train on.
Let's use [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as an example.

For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you can also use to serve the model after you've trained the Medusa heads.

```bash
model=mistralai/Mistral-7B-Instruct-v0.2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
```

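Once the container is running, you can sanity-check the server with a quick request before generating data; this uses text-generation-inference's standard `/generate` endpoint (the prompt and parameters are just placeholders):

```bash
# Quick smoke test against the local TGI server; adjust the prompt and parameters as you like.
curl 127.0.0.1:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "What is speculative decoding?", "parameters": {"max_new_tokens": 64}}'
```
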
Some of the sequences in ShareGPT are relatively long, so make sure the server can run inference on them. If there is not enough room, the script will simply skip those long conversations.
This shouldn't hurt downstream performance much, but more data is always better.
You can use various tradeoffs to [speed up inference](https://huggingface.co/docs/text-generation-inference/index), but the defaults should be good enough in most cases.

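With the server up, the data-adaptation step amounts to pointing the self-distillation script at the local endpoint. The script name and flags below are assumptions made for illustration; see the [`data_generation` folder](data_generation) for the actual entry point and arguments:

```bash
# Hypothetical invocation; check data_generation/ for the real script name and flags.
python data_generation/generate.py \
    --input-filename sharegpt.json \
    --url http://127.0.0.1:8080 \
    --output-filename mistral_self_distilled.json
```
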
We follow the training setup from [FastChat](https://github.com/lm-sys/FastChat#fine-tuning), but with a much larger learning rate because we freeze the original model and only train the new heads. Below is an example setup for training the Vicuna-7b model on 4 GPUs. Since we are only training the new heads, training does not require a lot of memory, and only data parallelism is needed. You can modify the script to fit your own setup. For larger models, we use the same setup. You can also use `--load_in_8bit` or `--load_in_4bit` to load the base model in quantized format.

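A hedged sketch of what such a launch can look like; the script path, flag names, data paths, and hyperparameter values below are assumptions for illustration only, so check the repo's training script for the real arguments:

```bash
# Hypothetical 4-GPU data-parallel launch; every path, flag, and value here is illustrative.
torchrun --nproc_per_node=4 medusa/train/train.py \
    --model_name_or_path lmsys/vicuna-7b-v1.3 \
    --data_path data/sharegpt_train.json \
    --output_dir medusa_vicuna_7b \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --learning_rate 1e-3 \
    --medusa_num_heads 3 \
    --medusa_num_layers 1
```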