
Commit 89f8ec0

Merge branch 'main' into sparse_tree
2 parents 5d374e9 + 8ce8fa9 commit 89f8ec0

File tree

7 files changed (+417 lines, −42 lines)


README.md

Lines changed: 7 additions & 6 deletions
````diff
@@ -7,7 +7,8 @@ medusa-llm"><b>Blog</b></a> | <a href="ROADMAP.md"><b>Roadmap</b></a> |
 
 ---
 *News* 🔥
-- [2023/09] Medusa v0.1 is released! 🎉
+- [2023/09] Medusa won the [Chai Prize Grant](https://twitter.com/tianle_cai/status/1703891335147897341)🎉 The prize will be used as a development bounty for those who help us achieve milestones in our [roadmap](https://github.com/FasterDecoding/Medusa/issues/3)!
+- [2023/09] Medusa v0.1 is released!
 
 ---
 ## Introduction
@@ -78,7 +79,7 @@ In this initial release, our primary focus is on optimizing Medusa for a batch s
 ```bash
 pip install medusa-llm
 ```
-### Method 2: From source
+### Method 2: From the source
 ```bash
 git clone https://github.com/FasterDecoding/Medusa.git
 cd Medusa
@@ -95,11 +96,11 @@ pip install -e .
 ### Inference
 We currently support single-GPU inference with a batch size of 1, which is the most common setup for local model hosting. We are actively working to extend Medusa's capabilities by integrating it into other inference frameworks; please don't hesitate to reach out if you are interested in contributing to this effort.
 
-You can use the following command for launching a CLI interface:
+You can use the following command to launch a CLI interface:
 ```bash
 CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model]
 ```
-You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format.
+You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you download the base model elsewhere, you may override base model name or path with `--base-model [path of base model]`.
 
 ### Training
 For training, please install:
@@ -111,7 +112,7 @@ We take a public version of the ShareGPT dataset, which is a subset of the Vicun
 ```bash
 git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
 ```
-Remark: If you haven't installed `git-lfs`, please install it before clone:
+Remark: If you haven't installed `git-lfs`, please install it before cloning:
 ```bash
 git lfs install
 ```
@@ -158,7 +159,7 @@ python -m medusa.hf_utils --folder [path of the model folder] --repo [name of th
 ```
 
 ## Codebase Guide
-`medusa/model/medusa_model.py` is the key file for Medusa. It contains the `MedusaModel` class, which is a wrapper of the original model and the new heads. This class also has implementation of a streaming generation method. If you want to dive into the details of Medusa, this is the place to start.
+`medusa/model/medusa_model.py` is the key file for Medusa. It contains the `MedusaModel` class, which is a wrapper of the original model and the new heads. This class also has an implementation of a streaming generation method. If you want to dive into the details of Medusa, this is the place to start.
 
 We also provide some illustrative notebooks in `notebooks/` to help you understand the codebase.
````

medusa/inference/cli.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@ def main(args):
     try:
         model = MedusaModel.from_pretrained(
             args.model,
+            args.base_model,
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
             device_map="auto",
@@ -185,6 +186,7 @@ def reload_conv(conv):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, required=True, help="Model name or path.")
+    parser.add_argument("--base-model", type=str, default=None, help="Base model name or path.")
     parser.add_argument(
         "--load-in-8bit", action="store_true", help="Use 8-bit quantization"
     )
```
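
The new `--base-model` flag simply becomes a second positional argument to `MedusaModel.from_pretrained`; when left at its `None` default, the base model path still comes from the Medusa config (see the `medusa_model.py` hunk below). Here is a minimal sketch of the equivalent programmatic call, mirroring the arguments the updated `cli.py` passes; the two checkpoint paths are placeholders, not verified model names:

```python
# Sketch: load Medusa the same way the updated CLI does.
# Both paths are placeholders; substitute your own checkpoints.
import torch
from medusa.model.medusa_model import MedusaModel

model = MedusaModel.from_pretrained(
    "path/to/medusa-head-checkpoint",  # --model: Medusa heads + MedusaConfig
    "path/to/base-model",              # --base-model: overrides base_model_name_or_path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()
```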

medusa/model/kv_cache.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 class KVCache:
     """
     A key-value cache for the model.
+
     This class provides a mechanism to maintain a growing cache of keys and values,
     particularly useful for models that benefit from caching previous states,
     like transformers during autoregressive decoding.
@@ -15,6 +16,8 @@ class KVCache:
 
     def __init__(self, data, current_length):
         """
+        Initialize the KVCache.
+
         Args:
             data (torch.Tensor): Initial tensor to store the keys and values.
             current_length (int): Initial length of the data.
```
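
This commit only touches the `KVCache` docstrings, so the update logic is not visible here. As a rough illustration of the pattern the docstring describes, a preallocated buffer plus a `current_length` cursor, and not the repo's actual class, a toy version might look like this:

```python
# Toy sketch of a growing key/value cache; NOT the implementation in kv_cache.py.
# A buffer is preallocated and `current_length` tracks how much of it is valid.
import torch

class ToyKVCache:
    def __init__(self, data: torch.Tensor, current_length: int = 0):
        self.data = data                      # e.g. [batch, heads, max_len, head_dim]
        self.current_length = current_length  # number of valid positions so far

    def append(self, new_kv: torch.Tensor, dim: int = -2) -> torch.Tensor:
        """Copy new keys/values into the buffer and return the valid slice."""
        length = new_kv.shape[dim]
        self.data.narrow(dim, self.current_length, length).copy_(new_kv)
        self.current_length += length
        return self.data.narrow(dim, 0, self.current_length)

buf = torch.zeros(1, 8, 16, 64)                # room for 16 positions
cache = ToyKVCache(buf)
step = cache.append(torch.randn(1, 8, 1, 64))  # one autoregressive decoding step
```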

medusa/model/medusa_model.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -11,6 +11,16 @@
 
 
 class MedusaConfig(PretrainedConfig):
+    """
+    Configuration class for Medusa model.
+
+    Args:
+        medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2.
+        medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
+        base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
+        **kwargs: Additional keyword arguments to be passed to the parent class constructor.
+    """
+
     def __init__(
         self,
         medusa_num_heads=4,
@@ -25,10 +35,14 @@ def __init__(
 
 
 class ResBlock(nn.Module):
-    """A Residual Block module.
+    """
+    A Residual Block module.
 
     This module performs a linear transformation followed by a SiLU activation,
     and then adds the result to the original input, creating a residual connection.
+
+    Args:
+        hidden_size (int): The size of the hidden layers in the block.
     """
 
     def __init__(self, hidden_size):
@@ -40,7 +54,8 @@ def __init__(self, hidden_size):
         self.act = nn.SiLU()
 
     def forward(self, x):
-        """Forward pass of the ResBlock.
+        """
+        Forward pass of the ResBlock.
 
         Args:
             x (torch.Tensor): Input tensor.
@@ -112,6 +127,7 @@ def from_pretrained(
         cls,
         medusa_head_name_or_path,
         medusa_num_heads=None,
+        base_model=None,
         **kwargs,
     ):
         """
@@ -124,7 +140,12 @@ def from_pretrained(
         """
         medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
         if medusa_num_heads is not None:
+            print("Overriding medusa_num_heads as:", medusa_num_heads)
             medusa_config.medusa_num_heads = medusa_num_heads
+        if base_model is not None:
+            print("Overriding base_model as:", base_model)
+            medusa_config.base_model_name_or_path = base_model
+
         base_model = KVLlamaForCausalLM.from_pretrained(
             medusa_config.base_model_name_or_path, **kwargs
         )
```
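
The expanded `ResBlock` docstring above fully describes the block's behaviour: a linear transform, a SiLU activation, and a residual add. As a sketch consistent with that description (not copied from the repo; weight initialization and other details are outside these hunks):

```python
# Sketch of a residual block matching the docstring: out = x + SiLU(Linear(x)).
# Initialization details are not shown in this commit.
import torch
import torch.nn as nn

class ResBlockSketch(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: add the activated transform back onto the input.
        return x + self.act(self.linear(x))

y = ResBlockSketch(64)(torch.randn(2, 10, 64))  # shape preserved: (2, 10, 64)
```

The residual form means a block whose linear layer outputs values near zero starts close to the identity mapping, which is a common choice when attaching new heads to an already trained base model.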
