
Commit fc0a6c7

Merge pull request #71 from FasterDecoding/v1.0-prerelease
V1.0 prerelease
2 parents: dd9c8a5 + bd95225 · commit: fc0a6c7

30 files changed: +5496 −714 lines

.gitignore

Lines changed: 4 additions & 1 deletion
```diff
@@ -171,4 +171,7 @@ test_medusa*
 notebooks/test*.ipynb
 notebooks/*.pdf
 llm_judge/*.sh
-llm_judge/data/mt_bench_test
+llm_judge/data/mt_bench_test
+llm_judge/data/mt_bench_test_rs
+data
+medusa/eval/*.sh
```

CITATION.cff

Lines changed: 21 additions & 26 deletions
```diff
@@ -1,27 +1,22 @@
-# This CITATION.cff file was generated with cffinit.
-# Visit https://bit.ly/cffinit to generate yours today!
-
 cff-version: 1.2.0
-title: 'Medusa'
-message: >-
-  If you use this software, please cite it using the
-  metadata from this file.
-type: software
-authors:
-  - given-names: Tianle
-    family-names: Cai
-  - given-names: Yuhong
-    family-names: Li
-  - given-names: Zhengyang
-    family-names: Geng
-  - given-names: Hongwu
-    family-names: Peng
-  - given-names: Tri
-    family-names: Dao
-repository-code: 'https://github.com/FasterDecoding/Medusa'
-url: 'https://sites.google.com/view/medusa-llm'
-abstract: >-
-  Medusa: Simple Framework for Accelerating LLM Generation
-  with Multiple Decoding Heads
-license: Apache-2.0
-date-released: '2023-09-10'
+message: "If you use this software, please cite it as below."
+references:
+  - type: article
+    authors:
+      - family-names: Cai
+        given-names: Tianle
+      - family-names: Li
+        given-names: Yuhong
+      - family-names: Geng
+        given-names: Zhengyang
+      - family-names: Peng
+        given-names: Hongwu
+      - family-names: Lee
+        given-names: Jason D.
+      - family-names: Chen
+        given-names: Deming
+      - family-names: Dao
+        given-names: Tri
+    title: "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads"
+    year: 2024
+    journal: "arXiv preprint arXiv: 2401.10774"
```

README.md

Lines changed: 44 additions & 17 deletions
````diff
@@ -2,13 +2,12 @@
 
 <p align="center">
 | <a href="https://sites.google.com/view/
-medusa-llm"><b>Blog</b></a> | <a href="ROADMAP.md"><b>Roadmap</b></a> |
+medusa-llm"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2401.10774"><b>Report</b></a> | <a href="ROADMAP.md"><b>Roadmap</b></a> |
 </p>
 
 ---
 *News* 🔥
-- [2023/09] Medusa won the [Chai Prize Grant](https://twitter.com/tianle_cai/status/1703891335147897341)🎉 The prize will be used as a development bounty for those who help us achieve milestones in our [roadmap](https://github.com/FasterDecoding/Medusa/issues/3)!
-- [2023/09] Medusa v0.1 is released!
+- [2024/1] Medusa technical report is now available on [arXiv](https://arxiv.org/abs/2401.10774). We've added multiple new features, including Medusa-2 recipe for full-model training, self-distillation for adding Medusa to any fine-tuned LLM, etc. The new results show a 2.2-3.6x speedup over the original model on a range of LLMs.
 
 ---
 ## Introduction
@@ -21,7 +20,7 @@ Medusa is a simple framework that democratizes the acceleration techniques for L
 </picture>
 <br>
 <div align="center" width="80%">
-<em>Medusa on Vicuna-7b.</em>
+<em>Medusa-1 on Vicuna-7b.</em>
 </div>
 <br>
 </div>
@@ -50,19 +49,25 @@ We aim to solve the challenges associated with speculative decoding by implement
 - Instead of introducing a new model, we train multiple decoding heads on the *same* model.
 - The training is parameter-efficient so that even the "GPU-Poor" can do it. And since there is no additional model, there is no need to adjust the distributed computing setup.
 - Relaxing the requirement of matching the distribution of the original model makes the non-greedy generation even faster than greedy decoding.
+
+In the initial release, our primary focus is on optimizing Medusa for a batch size of 1—a setting commonly utilized for local model hosting. In this configuration, Medusa delivers approximately a 2x speed increase across a range of Vicuna models. We are actively working to extend Medusa's capabilities by integrating it into additional inference frameworks, with the aim of achieving even greater performance gains and extending Medusa to broader settings.
+
 <p align="center">
 <picture>
-<img src="assets/size_speedup.png" width="45%">
+<img src="assets/medusa_speedup_cmp.jpg" width="45%">
 </picture>
 </p>
-In this initial release, our primary focus is on optimizing Medusa for a batch size of 1—a setting commonly utilized for local model hosting. In this configuration, Medusa delivers approximately a 2x speed increase across a range of Vicuna models. We are actively working to extend Medusa's capabilities by integrating it into additional inference frameworks, with the aim of achieving even greater performance gains and extending Medusa to broader settings.
+
+In the updated version, we add support for full-model training, called Medusa-2 (compared to Medusa-1, which only trains the new heads), which requires a special recipe that adds the speculative prediction ability while keeping the original model's performance.
+
+We also add support for self-distillation, which allows us to add Medusa to any fine-tuned LLM without requiring the availability of the original training data.
 
 ## Contents
 - [Introduction](#introduction)
 - [Contents](#contents)
 - [Installation](#installation)
   - [Method 1: With pip](#method-1-with-pip)
-  - [Method 2: From source](#method-2-from-source)
+  - [Method 2: From source (recommended)](#method-2-from-source)
 - [Model Weights](#model-weights)
 - [Inference](#inference)
 - [Training](#training)
@@ -71,28 +76,39 @@ In this initial release, our primary focus is on optimizing Medusa for a batch s
   - [Push to Hugging Face Hub](#push-to-hugging-face-hub)
 - [Citation](#citation)
 - [Codebase Guide](#codebase-guide)
+- [Community Adoption](#community-adoption)
 - [Contributing](#contributing)
 - [Acknowledgements](#acknowledgements)
 
 ## Installation
-### Method 1: With pip
+### Method 1: With pip (may not be the latest version)
 ```bash
 pip install medusa-llm
 ```
-### Method 2: From the source
+### Method 2: From the source (recommended)
 ```bash
 git clone https://github.com/FasterDecoding/Medusa.git
 cd Medusa
 pip install -e .
 ```
 
 ### Model Weights
+#### Medusa-1
 | Size | Chat Command | Hugging Face Repo |
 | ---- | --------------------------------------------- | --------------------------------------------------------------------- |
 | 7B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-7b-v1.3` | [FasterDecoding/medusa-vicuna-7b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3) |
 | 13B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-13b-v1.3` | [FasterDecoding/medusa-vicuna-13b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-13b-v1.3) |
 | 33B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-33b-v1.3` | [FasterDecoding/medusa-vicuna-33b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-33b-v1.3) |
 
+#### Medusa-2
+| Size | Chat Command | Hugging Face Repo |
+| ---- | --------------------------------------------- | --------------------------------------------------------------------- |
+| Zephyr-7B-Beta | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-zephyr-7b-beta` | [FasterDecoding/medusa-1.0-zephyr-7b-beta](https://huggingface.co/FasterDecoding/medusa-1.0-zephyr-7b-beta) |
+| Vicuna-7B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-7b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-7b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-7b-v1.5) |
+| Vicuna-13B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-13b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-13b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-13b-v1.5) |
+| Vicuna-33B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-33b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-33b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-33b-v1.5) |
+
+
 ### Inference
 We currently support single-GPU inference with a batch size of 1, which is the most common setup for local model hosting. We are actively working to extend Medusa's capabilities by integrating it into other inference frameworks; please don't hesitate to reach out if you are interested in contributing to this effort.
 
@@ -103,6 +119,11 @@ CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa mo
 You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you download the base model elsewhere, you may override base model name or path with `--base-model [path of base model]`.
 
 ### Training
+In the updated version, we use the amazing [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage the training process. Please refer to our [fork](https://github.com/ctlllll/axolotl) for the training code. The major code modifications are in [`src/axolotl/utils/models.py`](https://github.com/ctlllll/axolotl/blob/main/src/axolotl/utils/models.py). The training configs can be found in [`examples/medusa`](https://github.com/ctlllll/axolotl/tree/main/examples/medusa).
+
+The data preparation code for self-distillation can be found in [`data_generation` folder](data_generation) of the current repo.
+
+### Training (legacy)
 For training, please install:
 ```bash
 pip install -e ".[train]"
@@ -148,13 +169,11 @@ python -m medusa.hf_utils --folder [path of the model folder] --repo [name of th
 
 ## Citation
 ```bibtex
-@misc{medusa,
-  author = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Tri Dao},
-  title = {Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads},
-  year = {2023},
-  publisher = {GitHub},
-  journal = {GitHub repository},
-  howpublished = {\url{https://github.com/FasterDecoding/Medusa}},
+@article{cai2024medusa,
+  title = {Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads},
+  author = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Jason D. Lee and Deming Chen and Tri Dao},
+  year = {2024},
+  journal = {arXiv preprint arXiv: 2401.10774}
 }
 ```
 
@@ -163,8 +182,16 @@ python -m medusa.hf_utils --folder [path of the model folder] --repo [name of th
 
 We also provide some illustrative notebooks in `notebooks/` to help you understand the codebase.
 
+## Community Adoption
+We are super excited to see that Medusa has been adopted by many open-source projects. Here is an (incomplete) list:
+- [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/medusa)
+- [TGI](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/medusa.py)
+We are grateful to the authors for their contributions to the community and sincerely hope that Medusa can help accelerate the development of LLMs. If you are using Medusa in your project, please let us know, and we will add your project to the list.
+
 ## Contributing
 We welcome community contributions to Medusa. If you have an idea for how to improve it, please open an issue to discuss it with us. When submitting a pull request, please ensure that your changes are well-tested. Please split each major change into a separate pull request. We also have a [Roadmap](ROADMAP.md) summarizing our future plans for Medusa. Don't hesitate to reach out if you are interested in contributing to any of the items on the roadmap.
 
 ## Acknowledgements
-This codebase is influenced by remarkable projects from the LLM community, including [FastChat](https://github.com/lm-sys/FastChat), [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/), [vllm](https://github.com/vllm-project/vllm) and many others.
+This codebase is influenced by remarkable projects from the LLM community, including [FastChat](https://github.com/lm-sys/FastChat), [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/), [vllm](https://github.com/vllm-project/vllm), [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl).
+
+This project is supported by [Together AI](https://together.ai/), [MyShell AI](https://myshell.ai/), [Chai AI](https://www.chai-research.com/).
````

ROADMAP.md

Lines changed: 5 additions & 8 deletions
```diff
@@ -1,11 +1,11 @@
 # Roadmap
 
 ## Functionality
-- [ ] Explore tree sparsity
-- [ ] Fine-tune Medusa heads together with LM head from scratch
-- [ ] Distill from any model without access to the original training data
 - [ ] Batched inference
 - [ ] Fine-grained KV cache management
+- [x] Explore tree sparsity
+- [x] Fine-tune Medusa heads together with LM head from scratch
+- [x] Distill from any model without access to the original training data
 
 ## Integration
 ### Local Deployment
@@ -14,9 +14,6 @@
 - [ ] [llama.cpp](https://github.com/ggerganov/llama.cpp)
 ### Serving
 - [ ] [vllm](https://github.com/vllm-project/vllm)
-- [ ] [TGI](https://github.com/huggingface/text-generation-inference)
 - [ ] [lightllm](https://github.com/ModelTC/lightllm)
-
-## Research
-- [x] Optimize the tree-based attention to reduce additional computation
-- [ ] Improve the acceptance scheme to generate more diverse sequences
+- [x] [TGI](https://github.com/huggingface/text-generation-inference)
+- [x] [TensorRT](https://github.com/NVIDIA/TensorRT-LLM)
```

assets/medusa_pipeline.jpg

-795 KB

assets/medusa_speedup_cmp.jpg

55.2 KB

data_generation/README.md

Lines changed: 27 additions & 0 deletions
# Generate chat data for self-distillation

We use vLLM to enable batched generation. First, install dependencies:

```bash
pip install vllm openai
```

## Start server

```bash
python -m vllm.entrypoints.openai.api_server \
    --model YOUR_MODEL_NAME --port 8000
```

You can also start multiple servers with different ports to enable parallel generation. In `generate.py`, we scan the ports from 8000 to 8009 to find available servers. You can modify the code to use other ports.
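`generate.py` itself is not part of this diff, so the snippet below is only a rough, hypothetical sketch of the port-scanning idea described above; the helper name, the `/v1/models` probe, and the `requests` dependency are assumptions, not the project's actual code.

```python
import requests  # assumed helper dependency; any HTTP client would do


def find_live_servers(host="localhost", ports=range(8000, 8010)):
    """Illustrative only: probe ports 8000-8009 for an OpenAI-compatible
    vLLM server by asking each one for its model list."""
    live = []
    for port in ports:
        try:
            resp = requests.get(f"http://{host}:{port}/v1/models", timeout=1)
            if resp.ok:
                live.append(port)
        except requests.RequestException:
            continue  # nothing (or something else) listening on this port
    return live


if __name__ == "__main__":
    print(find_live_servers())  # e.g. [8000, 8001] if two servers are up
```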
## Generate data

The following command lets the model continue the first prompt from each sample in `DATA_PATH`; this is suitable for models that can play both roles in a conversation (e.g., Zephyr 7B). If you want to use all prompts in each sample to talk to the model repeatedly, use `--chat` instead. `--chat` mode works for more models but may take longer to generate due to repeated computation (contributions of a better implementation are welcome).

```bash
python generate.py --data_path YOUR_DATA_PATH --output_path YOUR_OUTPUT_PATH --num_threads NUM_THREADS --max_tokens YOUR_MAX_TOKENS --temperature YOUR_TEMPERATURE
```

## (Optional) Format data

When generated with `--chat`, the output file follows the ShareGPT format ([example](https://github.com/lm-sys/FastChat/blob/main/data/dummy_conversation.json)).
You can use the following command to convert text generated without `--chat` to the same format:

```bash
python convert_to_sharegpt.py --input_path YOUR_INPUT_PATH --model_name YOUR_MODEL_NAME --output_path YOUR_OUTPUT_PATH
```
Lines changed: 70 additions & 0 deletions
```python
import json
import os
import time
import concurrent.futures

import openai
import shortuuid
import tqdm

import argparse
import random

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
from transformers import AutoTokenizer

# Use the same arguments as in generate.py
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str)
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta")
args = parser.parse_args()

conv = get_conversation_template(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

data = []
with open(args.input_path) as f:
    for line in f.readlines():
        data.append(json.loads(line))

def convert(text):
    messages = []

    for turn in text.split(conv.roles[0]):
        pairs = turn.split(conv.roles[1])
        if len(pairs) != 2:
            continue
        messages.append({
            "from": "human",
            "value": pairs[0].split(conv.sep)[0].strip()
        })
        messages.append({
            "from": "gpt",
            "value": pairs[1].split(conv.sep)[0].strip()
        })
    # pop the last message because it might be incomplete
    if len(messages) > 0:
        messages.pop()
    # make sure number of messages is even
    if len(messages) % 2 == 1:
        messages.pop()
    return {"conversations": messages}

sharegpt_data = []
for d in tqdm.tqdm(data):
    sample = convert(d["text"])
    if len(sample["conversations"]) < 2:
        continue
    sharegpt_data.append(sample)

# dump to jsonl
with open(args.input_path.replace(".jsonl", "_sharegpt.jsonl"), "w") as f:
    for d in sharegpt_data:
        f.write(json.dumps(d) + "\n")
```
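For reference, each line the script above writes to the `*_sharegpt.jsonl` output is a JSON object with a `conversations` list of alternating human/gpt turns, matching what `convert()` returns; the sample values below are invented purely for illustration.

```python
# Hypothetical contents of one output line (pretty-printed here);
# the "from"/"value" keys mirror what convert() builds.
example_record = {
    "conversations": [
        {"from": "human", "value": "What is speculative decoding?"},
        {"from": "gpt", "value": "A way to draft several tokens cheaply and verify them with the base model in one pass."},
    ]
}
```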
