Commit 700ff84

update README and add back legacy code for compatibility

1 parent 93bee11 · commit 700ff84

3 files changed: +1744 −8 lines changed

README.md

Lines changed: 12 additions & 8 deletions
````diff
@@ -66,14 +66,13 @@ We also add support for self-distillation, which allows us to add Medusa to any
 - [Introduction](#introduction)
 - [Contents](#contents)
 - [Installation](#installation)
-  - [Method 1: With pip](#method-1-with-pip)
-  - [Method 2: From source (recommended)](#method-2-from-source)
+  - [Method 1: With pip (may not be the latest version)](#method-1-with-pip-may-not-be-the-latest-version)
+  - [Method 2: From the source (recommended)](#method-2-from-the-source-recommended)
 - [Model Weights](#model-weights)
 - [Inference](#inference)
 - [Training](#training)
-  - [Prepare the data](#prepare-the-data)
-  - [Train the model](#train-the-model)
-  - [Push to Hugging Face Hub](#push-to-hugging-face-hub)
+- [Training (legacy)](#training-legacy)
+- [Push to Hugging Face Hub](#push-to-hugging-face-hub)
 - [Citation](#citation)
 - [Codebase Guide](#codebase-guide)
 - [Community Adoption](#community-adoption)
@@ -119,11 +118,16 @@ CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa mo
 You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you download the base model elsewhere, you may override base model name or path with `--base-model [path of base model]`.
 
 ### Training
-In the updated version, we use the amazing [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage the training process. Please refer to our [fork](https://github.com/ctlllll/axolotl) for the training code. The major code modifications are in [`src/axolotl/utils/models.py`](https://github.com/ctlllll/axolotl/blob/main/src/axolotl/utils/models.py). The training configs can be found in [`examples/medusa`](https://github.com/ctlllll/axolotl/tree/main/examples/medusa).
+In the updated version, we use the amazing [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage the training process. Please refer to our [fork](https://github.com/ctlllll/axolotl) for the training code. The major code modifications are in [`src/axolotl/utils/models.py`](https://github.com/ctlllll/axolotl/blob/main/src/axolotl/utils/models.py). The training configs can be found in [`examples/medusa`](https://github.com/ctlllll/axolotl/tree/main/examples/medusa). A typical training command is as follows:
+```bash
+accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml
+```
 
-The data preparation code for self-distillation can be found in [`data_generation` folder](data_generation) of the current repo.
+The data preparation code for self-distillation can be found in the [`data_generation` folder](data_generation) of the current repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo.
 
 ### Training (legacy)
+*The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.*
+
 For training, please install:
 ```bash
 pip install -e ".[train]"
@@ -161,7 +165,7 @@ torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vic
 --medusa_num_heads 3 \
 --medusa_num_layers 1
 ```
-#### Push to Hugging Face Hub
+### Push to Hugging Face Hub
 You can use the following command to push your model to the Hugging Face Hub:
 ```bash
 python -m medusa.hf_utils --folder [path of the model folder] --repo [name of the repo]
````
Lines changed: 332 additions & 0 deletions
The new file (all 332 lines added; the filename is not shown in this view) contains the legacy Medusa model definition:
```python
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values
from .medusa_choices import mc_sim_7b_63
from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download


class MedusaConfig(PretrainedConfig):
    """
    Configuration class for Medusa model.

    Args:
        medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 4.
        medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
        base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
        **kwargs: Additional keyword arguments to be passed to the parent class constructor.
    """

    def __init__(
        self,
        medusa_num_heads=4,
        medusa_num_layers=1,
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.medusa_num_heads = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.base_model_name_or_path = base_model_name_or_path
```
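Since `MedusaConfig` subclasses `PretrainedConfig`, the standard serialization helpers apply. A minimal sketch of the round trip that `MedusaModel.from_pretrained` below depends on; the directory name is hypothetical:

```python
from medusa.model.medusa_model import MedusaConfig  # module path assumed

cfg = MedusaConfig(medusa_num_heads=4, medusa_num_layers=1)
cfg.save_pretrained("./my_medusa_head")  # writes ./my_medusa_head/config.json
cfg2 = MedusaConfig.from_pretrained("./my_medusa_head")
assert cfg2.medusa_num_heads == 4 and cfg2.medusa_num_layers == 1
```

The `ResBlock` building block used by every Medusa head follows.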
```python
class ResBlock(nn.Module):
    """
    A Residual Block module.

    This module performs a linear transformation followed by a SiLU activation,
    and then adds the result to the original input, creating a residual connection.

    Args:
        hidden_size (int): The size of the hidden layers in the block.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Initialize as an identity mapping
        torch.nn.init.zeros_(self.linear.weight)
        # Use SiLU activation to keep consistent with the Llama model
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Forward pass of the ResBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output after the residual connection and activation.
        """
        return x + self.act(self.linear(x))
```
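Because `ResBlock` zero-initializes the linear weight (the bias keeps PyTorch's default initialization), a fresh block is approximately the identity map, so a newly attached Medusa head starts out close to the base LM head it is copied from. A self-contained check, assuming only the class above:

```python
import torch

block = ResBlock(hidden_size=8)
x = torch.randn(2, 8)
# With weight == 0, linear(x) == bias for every row, so the block adds the
# same small constant offset SiLU(bias) to each input row.
offset = block(x) - x
print(offset.abs().max())                    # small at init
print(torch.allclose(offset[0], offset[1]))  # True: offset is input-independent
```

The `MedusaModel` wrapper that assembles the heads on top of a base model follows.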
```python
class MedusaModel(nn.Module):
    """The Medusa Language Model Head.

    This module creates a series of prediction heads (based on the 'medusa' parameter)
    on top of a given base model. Each head is composed of a sequence of residual blocks
    followed by a linear layer.
    """

    def __init__(
        self,
        base_model,
        medusa_num_heads=4,
        medusa_num_layers=1,
        base_model_name_or_path="lmsys/vicuna-7b-v1.3",
    ):
        """
        Args:
            base_model (nn.Module): The base language model to be used.
            medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 4.
            medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 1.
        """
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.medusa = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        # Create a list of Medusa heads, each with its own independent ResBlocks
        # (a repeated-list construction would make the blocks share one module's weights)
        self.medusa_head = nn.ModuleList(
            [
                nn.Sequential(
                    *[ResBlock(self.hidden_size) for _ in range(medusa_num_layers)],
                    nn.Linear(self.hidden_size, self.vocab_size, bias=False),
                )
                for _ in range(medusa_num_heads)
            ]
        )

        # Ensure medusa_head's dtype and device align with the base_model
        self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)

        for i in range(medusa_num_heads):
            # Initialize the weights of each medusa_head using the base model's weights
            self.medusa_head[i][-1].weight.data[:] = base_model.lm_head.weight.data[:]

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer
```
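Each Medusa head is an independent `nn.Sequential` of `medusa_num_layers` ResBlocks followed by an unbiased projection to the vocabulary, and `forward` (further below) stacks the per-head logits along a new leading dimension. A toy-sized sketch of that shape contract; all dimensions here are made up:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 100
medusa_num_heads, medusa_num_layers = 4, 1

medusa_head = nn.ModuleList(
    [
        nn.Sequential(
            *[ResBlock(hidden_size) for _ in range(medusa_num_layers)],
            nn.Linear(hidden_size, vocab_size, bias=False),
        )
        for _ in range(medusa_num_heads)
    ]
)

hidden_states = torch.randn(1, 5, hidden_size)  # (batch, seq_len, hidden)
logits = torch.stack([head(hidden_states) for head in medusa_head], dim=0)
print(logits.shape)  # torch.Size([4, 1, 5, 100]): (num_heads, batch, seq_len, vocab)
```

The class's `from_pretrained` loader follows.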
```python
    @classmethod
    def from_pretrained(
        cls,
        medusa_head_name_or_path,
        base_model=None,
        medusa_num_heads=None,
        **kwargs,
    ):
        """
        Args:
            medusa_head_name_or_path (str): Name or path of the Medusa head to load.
            base_model (str, optional): Overrides the base model name or path from the config.
            medusa_num_heads (int, optional): Overrides the number of Medusa heads from the config.
            **kwargs: Additional keyword arguments for loading the base model.

        Returns:
            MedusaModel: A MedusaModel instance loaded from the given path.
        """
        medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
        if medusa_num_heads is not None:
            print("Overriding medusa_num_heads as:", medusa_num_heads)
            medusa_config.medusa_num_heads = medusa_num_heads
        if base_model is not None:
            print("Overriding base_model as:", base_model)
            medusa_config.base_model_name_or_path = base_model

        base_model = KVLlamaForCausalLM.from_pretrained(
            medusa_config.base_model_name_or_path, **kwargs
        )

        model = cls(
            base_model,
            medusa_config.medusa_num_heads,
            medusa_config.medusa_num_layers,
            medusa_config.base_model_name_or_path,
        )
        # Load the Medusa head weights from a local folder if present,
        # otherwise fetch them from the Hugging Face Hub.
        medusa_head_path = os.path.join(medusa_head_name_or_path, "medusa_lm_head.pt")
        if os.path.exists(medusa_head_path):
            filename = medusa_head_path
        else:
            filename = hf_hub_download(medusa_head_name_or_path, "medusa_lm_head.pt")
        medusa_head_state_dict = torch.load(filename, map_location=base_model.device)
        model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)

        return model
```
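`medusa_lm_head.pt` is loaded as a plain state dict for the `medusa_head` module list, so its keys follow PyTorch's `ModuleList`/`Sequential` naming (head index, then position within the head). A sketch of what `load_state_dict` matches against, reusing the toy `medusa_head` from the earlier sketch; `strict=False` tolerates checkpoints with extra or missing entries:

```python
for name in medusa_head.state_dict():
    print(name)
# 0.0.linear.weight   <- head 0, ResBlock linear weight
# 0.0.linear.bias     <- head 0, ResBlock linear bias
# 0.1.weight          <- head 0, vocab projection
# 1.0.linear.weight   <- head 1, ...
```

The `forward` pass follows.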
```python
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        output_orig=False,
        position_ids=None,
    ):
        """Forward pass of the MedusaModel.

        Args:
            input_ids (torch.Tensor, optional): Input token IDs.
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Ground truth labels for loss computation.
            past_key_values (tuple, optional): Tuple containing past key and value states for attention.
            output_orig (bool, optional): Whether to also output predictions from the original LM head.
            position_ids (torch.Tensor, optional): Position IDs.

        Returns:
            torch.Tensor: A tensor containing predictions from all Medusa heads.
            (Optional) Original predictions from the base model's LM head.
        """
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
        # Clone the output hidden states
        hidden_states = outputs[0].clone()
        medusa_logits = []
        # TODO: Consider parallelizing this loop for efficiency?
        for i in range(self.medusa):
            medusa_logits.append(self.medusa_head[i](hidden_states))
        if output_orig:
            return torch.stack(medusa_logits, dim=0), outputs, orig
        return torch.stack(medusa_logits, dim=0)
```
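With `output_orig=True`, `forward` returns the stacked Medusa logits together with the base-model outputs and the original LM-head logits. A sketch of the call, assuming a loaded `model` and tokenized `input_ids` (both hypothetical here):

```python
medusa_logits, outputs, orig = model(input_ids, output_orig=True)
# medusa_logits: (medusa_num_heads, batch, seq_len, vocab) -- one slice per head
# orig:          (batch, seq_len, vocab) -- the base model's own next-token logits
```

Finally, `medusa_generate` implements the speculative decoding loop.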
```python
    def medusa_generate(
        self,
        input_ids,
        attention_mask=None,
        temperature=0.0,
        max_steps=512,
        # The hyperparameters below are for the Medusa choices tree:
        # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
        medusa_choices=mc_sim_7b_63,
        posterior_threshold=0.09,  # threshold validation of Medusa output
        # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
        posterior_alpha=0.3,
    ):
        """
        Args:
            input_ids (torch.Tensor, optional): Input token IDs.
            attention_mask (torch.Tensor, optional): Attention mask.
            temperature (float, optional): Temperature for typical acceptance.
            medusa_choices (list, optional): A list of integers indicating the number of choices for each Medusa head.
            posterior_threshold (float, optional): Threshold for posterior validation.
            posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).
        Returns:
            torch.Tensor: Output token IDs.

        Warning: Only supports batch size 1 for now!!
        """
        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()

        # Cache medusa buffers (the fixed patterns for tree attention)
        if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
            # Load the cached medusa buffer
            medusa_buffers = self.medusa_buffers
        else:
            # Initialize the medusa buffer
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=self.base_model.device
            )
            self.medusa_buffers = medusa_buffers
            self.medusa_choices = medusa_choices

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]

        reset_medusa_mode(self)
        # Initialize tree attention mask and process prefill tokens
        medusa_logits, logits = initialize_medusa(
            input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
        )

        new_token = 0
        last_round_token = 0

        for idx in range(max_steps):
            # Generate candidates with topk predictions from Medusa heads
            candidates, tree_candidates = generate_candidates(
                medusa_logits,
                logits,
                medusa_buffers["tree_indices"],
                medusa_buffers["retrieve_indices"],
            )

            # Use tree attention to verify the candidates and get predictions
            medusa_logits, logits, outputs = tree_decoding(
                self,
                tree_candidates,
                past_key_values,
                medusa_buffers["medusa_position_ids"],
                input_ids,
                medusa_buffers["retrieve_indices"],
            )

            # Evaluate the posterior of the candidates to select the accepted candidate prefix
            best_candidate, accept_length = evaluate_posterior(
                logits, candidates, temperature, posterior_threshold, posterior_alpha
            )

            # Update the input_ids and logits
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                medusa_buffers["retrieve_indices"],
                outputs,
                logits,
                medusa_logits,
                new_token,
                past_key_values_data,
                current_length_data,
            )

            yield {
                "text": self.tokenizer.decode(
                    input_ids[0, input_len:],
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                    clean_up_tokenization_spaces=True,
                )
            }

            if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
                break
```
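Putting the pieces together, a hedged end-to-end sketch of the legacy API; the checkpoint name below is an assumption, and any Medusa head folder or Hub repo containing `config.json` plus `medusa_lm_head.pt` should follow the same pattern:

```python
import torch

model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",  # hypothetical repo name
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = model.get_tokenizer()

input_ids = tokenizer(
    "What is speculative decoding?", return_tensors="pt"
).input_ids.to(model.base_model.device)

# medusa_generate is a generator: each step re-yields the full decoded
# continuation produced so far.
for step in model.medusa_generate(input_ids, temperature=0.0, max_steps=256):
    print(step["text"])
```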

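`evaluate_posterior`, `generate_candidates`, `tree_decoding`, and `update_inference_inputs` come from `utils.py` and are not shown in this commit. For orientation only, a paraphrased sketch of the typical-acceptance rule that `posterior_threshold` and `posterior_alpha` parameterize, as described in the Medusa paper (not the repo's exact implementation):

```python
import torch

def typical_accept(probs, token, posterior_threshold=0.09, posterior_alpha=0.3):
    """Accept `token` if its probability clears an entropy-dependent floor:
    p(token) > min(posterior_threshold, posterior_alpha * exp(-H(probs)))."""
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum()
    floor = min(posterior_threshold, (posterior_alpha * torch.exp(-entropy)).item())
    return probs[token].item() > floor
```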