
Commit e9d2191

Merge branch 'sparse_tree' of github.com:FasterDecoding/Medusa into sparse_tree

2 parents: 4054fa2 + 11af0aa

File tree: 4 files changed, +172 −158 lines

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -165,4 +165,7 @@ wandb/
 
 ShareGPT_Vicuna_unfiltered/
 
-test_medusa*
+test_medusa*
+
+# test
+notebooks/test*.ipynb

medusa/model/medusa_choices.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
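
Each sublist in mc_sim_7b_63 appears to be one candidate path through the Medusa heads: the entry at depth d indexes into head d's top-k predictions, so [0] is head 1's top pick and [0, 0, 1] extends head 1's and head 2's top picks with head 3's second pick. A small inspection sketch under that interpretation (this snippet is hypothetical, not part of the repo):

    from collections import Counter

    from medusa.model.medusa_choices import mc_sim_7b_63

    # Hypothetical inspection snippet: summarize the sparse tree's shape.
    print(len(mc_sim_7b_63))                                 # 63 candidate paths
    print(sorted(Counter(map(len, mc_sim_7b_63)).items()))   # paths per depth
    print(max(max(p) for p in mc_sim_7b_63) + 1)             # top-k needed per head (10)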

medusa/model/medusa_model.py

Lines changed: 9 additions & 6 deletions

@@ -2,17 +2,18 @@
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
-from transformers import AutoTokenizer
 from .utils import *
 from .kv_cache import initialize_past_key_values
+from .medusa_choices import mc_sim_7b_63
+from transformers import AutoTokenizer
 import os
 from huggingface_hub import hf_hub_download
 
 
 class MedusaConfig(PretrainedConfig):
     def __init__(
         self,
-        medusa_num_heads=2,
+        medusa_num_heads=4,
         medusa_num_layers=1,
         base_model_name_or_path="lmsys/vicuna-7b-v1.3",
         **kwargs,
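
The MedusaConfig default grows from 2 to 4 heads. A minimal sketch of the effect, assuming the package layout shown in this diff:

    from medusa.model.medusa_model import MedusaConfig

    # Minimal sketch, assuming this diff's MedusaConfig is importable.
    cfg = MedusaConfig()                  # rely on the new defaults
    print(cfg.medusa_num_heads)           # 4 after this change (was 2)
    print(cfg.base_model_name_or_path)    # "lmsys/vicuna-7b-v1.3", unchanged
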
@@ -110,6 +111,7 @@ def get_tokenizer(self):
     def from_pretrained(
         cls,
         medusa_head_name_or_path,
+        medusa_num_heads=None,
         **kwargs,
     ):
         """
@@ -121,9 +123,12 @@ def from_pretrained(
             MedusaModel: A MedusaModel instance loaded from the given path.
         """
         medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
+        if medusa_num_heads is not None:
+            medusa_config.medusa_num_heads = medusa_num_heads
         base_model = KVLlamaForCausalLM.from_pretrained(
             medusa_config.base_model_name_or_path, **kwargs
         )
+
         model = cls(
             base_model,
             medusa_config.medusa_num_heads,
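
With the new medusa_num_heads argument, a caller can load a checkpoint while overriding the head count stored in its config; left as None, the config value is used unchanged. A usage sketch (the Hub repo id below is illustrative):

    import torch
    from medusa.model.medusa_model import MedusaModel

    # Usage sketch; the repo id is illustrative, substitute your own checkpoint.
    model = MedusaModel.from_pretrained(
        "FasterDecoding/medusa-vicuna-7b-v1.3",
        medusa_num_heads=3,         # override the value in the checkpoint's config
        torch_dtype=torch.float16,  # forwarded to the base model via **kwargs
    )
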
@@ -191,7 +196,7 @@ def medusa_generate(
         max_steps=512,
         # The hyperparameters below are for the Medusa
         # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
-        medusa_choices=[1, 7, 6],
+        medusa_choices=mc_sim_7b_63,
         posterior_threshold=0.09,  # threshold validation of Medusa output
         # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
         posterior_alpha=0.3,
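
The default medusa_choices also changes shape: the old [1, 7, 6] gave a top-k per head and implicitly kept every combination (a dense 1 × 7 × 6 tree), while mc_sim_7b_63 lists 63 explicit paths whose shared prefixes form a sparse tree. A sketch of the contrast, under the path interpretation assumed above (not taken verbatim from the repo):

    import math

    from medusa.model.medusa_choices import mc_sim_7b_63

    # Sketch contrasting the two formats.
    dense = [1, 7, 6]
    print(math.prod(dense))                    # 42 leaf paths, all combinations kept
    sparse = mc_sim_7b_63
    print(len(sparse), max(map(len, sparse)))  # 63 hand-picked paths, depth up to 4
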
@@ -225,7 +230,6 @@ def medusa_generate(
         self.medusa_buffers = medusa_buffers
         self.medusa_choices = medusa_choices
 
-        medusa_topk = medusa_choices[1:]
 
         # Initialize the past key and value states
         if hasattr(self, "past_key_values"):
@@ -260,9 +264,8 @@ def medusa_generate(
             candidates, tree_candidates = generate_candidates(
                 medusa_logits,
                 logits,
-                medusa_topk,
                 medusa_buffers["tree_indices"],
-                temperature,
+                medusa_buffers["retrieve_indices"],
             )
 
             # Use tree attention to verify the candidates and get predictions
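
generate_candidates loses the medusa_topk and temperature arguments: with a fixed sparse tree, the per-head top-k and the branch set are baked into the precomputed buffers. Plausibly, tree_indices lays the candidate tokens out as one flat sequence and retrieve_indices maps each root-to-leaf path back to positions in that sequence; a rough sketch of that gather step under those assumptions (not the repo's actual implementation):

    import torch

    # Rough sketch, assuming retrieve_indices holds one row of flat-sequence
    # positions per candidate path. Not the repo's actual code.
    flat_tokens = torch.tensor([11, 22, 33, 44, 55])    # laid out via tree_indices
    retrieve_indices = torch.tensor([[0, 1, 3],
                                     [0, 2, 4]])        # two root-to-leaf paths
    candidates = flat_tokens[retrieve_indices]          # gather per-path sequences
    print(candidates)   # tensor([[11, 22, 44], [11, 33, 55]])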
