Skip to content

Commit 4054fa2

Browse files
committed
Revert "Add sparse tree"
This reverts commit 922689a.
1 parent 922689a commit 4054fa2

File tree

4 files changed

+158
-172
lines changed

4 files changed

+158
-172
lines changed

.gitignore

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,4 @@ wandb/
165165

166166
ShareGPT_Vicuna_unfiltered/
167167

168-
test_medusa*
169-
170-
# test
171-
notebooks/test*.ipynb
168+
test_medusa*

medusa/model/medusa_choices.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

medusa/model/medusa_model.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22
import torch.nn as nn
33
from transformers import PreTrainedModel, PretrainedConfig
44
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
5+
from transformers import AutoTokenizer
56
from .utils import *
67
from .kv_cache import initialize_past_key_values
7-
from .medusa_choices import mc_sim_7b_63
8-
from transformers import AutoTokenizer
98
import os
109
from huggingface_hub import hf_hub_download
1110

1211

1312
class MedusaConfig(PretrainedConfig):
1413
def __init__(
1514
self,
16-
medusa_num_heads=4,
15+
medusa_num_heads=2,
1716
medusa_num_layers=1,
1817
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
1918
**kwargs,
@@ -111,7 +110,6 @@ def get_tokenizer(self):
111110
def from_pretrained(
112111
cls,
113112
medusa_head_name_or_path,
114-
medusa_num_heads=None,
115113
**kwargs,
116114
):
117115
"""
@@ -123,12 +121,9 @@ def from_pretrained(
123121
MedusaModel: A MedusaModel instance loaded from the given path.
124122
"""
125123
medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
126-
if medusa_num_heads is not None:
127-
medusa_config.medusa_num_heads = medusa_num_heads
128124
base_model = KVLlamaForCausalLM.from_pretrained(
129125
medusa_config.base_model_name_or_path, **kwargs
130126
)
131-
132127
model = cls(
133128
base_model,
134129
medusa_config.medusa_num_heads,
@@ -196,7 +191,7 @@ def medusa_generate(
196191
max_steps=512,
197192
# The hyperparameters below are for the Medusa
198193
# top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next next token.
199-
medusa_choices=mc_sim_7b_63,
194+
medusa_choices=[1, 7, 6],
200195
posterior_threshold=0.09, # threshold validation of Medusa output
201196
# another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
202197
posterior_alpha=0.3,
@@ -230,6 +225,7 @@ def medusa_generate(
230225
self.medusa_buffers = medusa_buffers
231226
self.medusa_choices = medusa_choices
232227

228+
medusa_topk = medusa_choices[1:]
233229

234230
# Initialize the past key and value states
235231
if hasattr(self, "past_key_values"):
@@ -264,8 +260,9 @@ def medusa_generate(
264260
candidates, tree_candidates = generate_candidates(
265261
medusa_logits,
266262
logits,
263+
medusa_topk,
267264
medusa_buffers["tree_indices"],
268-
medusa_buffers["retrieve_indices"],
265+
temperature,
269266
)
270267

271268
# Use tree attention to verify the candidates and get predictions

0 commit comments

Comments (0)