
Commit 55c77cb

update notebook
1 parent e9d2191 commit 55c77cb

File tree

2 files changed: +72 -39 lines changed

medusa/model/medusa_model.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ class MedusaModel(nn.Module):
     def __init__(
         self,
         base_model,
-        medusa_num_heads=2,
+        medusa_num_heads=4,
         medusa_num_layers=1,
         base_model_name_or_path="lmsys/vicuna-7b-v1.3",
     ):
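The default head count moves from 2 to 4, matching the four-head medusa-vicuna checkpoints the companion notebook loads. A hedged sketch (mirroring the notebook cell later in this diff): passing medusa_num_heads explicitly keeps loading independent of this library default.

import torch
from medusa.model.medusa_model import MedusaModel

# Explicit is safer than relying on the default: the value must match the
# number of Medusa heads stored in the checkpoint being loaded.
model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",
    medusa_num_heads=4,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)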

notebooks/medusa_inference_explained.ipynb

Lines changed: 71 additions & 38 deletions
@@ -16,7 +16,7 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # define GPU id, remove if you want to use all GPUs available\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\" # define GPU id, remove if you want to use all GPUs available\n",
     "import torch\n",
     "from tqdm import tqdm\n",
     "import time\n",
@@ -26,6 +26,7 @@
     "from medusa.model.medusa_model import MedusaModel\n",
     "from medusa.model.kv_cache import *\n",
     "from medusa.model.utils import *\n",
+    "from medusa.model.medusa_choices import *\n",
     "import transformers\n",
     "from huggingface_hub import hf_hub_download"
    ]
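The added wildcard import brings in the library's predefined candidate trees, including the mc_sim_7b_63 configuration the notebook switches to below. A hedged sketch of inspecting it, assuming (from the identifier, not from this diff) that it is a 63-node list of per-head top-k index paths:

from medusa.model.medusa_choices import mc_sim_7b_63

# Each entry is assumed to be one node of the sparse candidate tree,
# written as a path of per-head top-k indices.
print(len(mc_sim_7b_63))   # expected: 63, per the name
print(mc_sim_7b_63[:3])    # e.g. [[0], [0, 0], [1]]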
@@ -55,28 +56,83 @@
     "    elapsed_time = end - start\n",
     "    wall_times[key].append(elapsed_time)\n",
     "\n",
-    "def medusa_forward(input_ids, model, tokenizer, medusa_buffers, medusa_topk, temperature, posterior_threshold, posterior_alpha, past_key_values, past_key_values_data, current_length_data, steps = 512):\n",
+    "def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):\n",
     "    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}\n",
     "    \n",
     "    with timed(wall_times, 'init'):\n",
-    "        reset_medusa_mode(model)\n",
+    "        if hasattr(model, \"medusa_choices\") and model.medusa_choices == medusa_choices:\n",
+    "            # Load the cached medusa buffer\n",
+    "            medusa_buffers = model.medusa_buffers\n",
+    "        else:\n",
+    "            # Initialize the medusa buffer\n",
+    "            medusa_buffers = generate_medusa_buffers(\n",
+    "                medusa_choices, device=model.base_model.device\n",
+    "            )\n",
+    "        model.medusa_buffers = medusa_buffers\n",
+    "        model.medusa_choices = medusa_choices\n",
+    "\n",
+    "        # Initialize the past key and value states\n",
+    "        if hasattr(model, \"past_key_values\"):\n",
+    "            past_key_values = model.past_key_values\n",
+    "            past_key_values_data = model.past_key_values_data\n",
+    "            current_length_data = model.current_length_data\n",
+    "            # Reset the past key and value states\n",
+    "            current_length_data.zero_()\n",
+    "        else:\n",
+    "            (\n",
+    "                past_key_values,\n",
+    "                past_key_values_data,\n",
+    "                current_length_data,\n",
+    "            ) = initialize_past_key_values(model.base_model)\n",
+    "            model.past_key_values = past_key_values\n",
+    "            model.past_key_values_data = past_key_values_data\n",
+    "            model.current_length_data = current_length_data\n",
+    "\n",
     "    input_len = input_ids.shape[1]\n",
-    "    medusa_logits, logits = initialize_medusa(input_ids, model, medusa_buffers['medusa_attn_mask'], past_key_values)\n",
-    "    \n",
+    "    reset_medusa_mode(model)\n",
+    "    medusa_logits, logits = initialize_medusa(\n",
+    "        input_ids, model, medusa_buffers[\"medusa_attn_mask\"], past_key_values\n",
+    "    )\n",
     "    new_token = 0\n",
     "\n",
-    "    for idx in range(steps): \n",
+    "    for idx in range(max_steps): \n",
     "        with timed(wall_times, 'medusa'):\n",
-    "            candidates, tree_candidates = generate_candidates(medusa_logits, logits, medusa_topk, medusa_buffers['tree_indices'], temperature)\n",
+    "            candidates, tree_candidates = generate_candidates(\n",
+    "                medusa_logits,\n",
+    "                logits,\n",
+    "                medusa_buffers[\"tree_indices\"],\n",
+    "                medusa_buffers[\"retrieve_indices\"],\n",
+    "            )\n",
     "\n",
     "        with timed(wall_times, 'tree'):\n",
-    "            medusa_logits, logits, outputs = tree_decoding(model, tree_candidates, past_key_values, medusa_buffers['medusa_position_ids'], input_ids, medusa_buffers['retrieve_indices'])\n",
+    "            medusa_logits, logits, outputs = tree_decoding(\n",
+    "                model,\n",
+    "                tree_candidates,\n",
+    "                past_key_values,\n",
+    "                medusa_buffers[\"medusa_position_ids\"],\n",
+    "                input_ids,\n",
+    "                medusa_buffers[\"retrieve_indices\"],\n",
+    "            )\n",
     "\n",
     "        with timed(wall_times, 'posterior'):\n",
-    "            best_candidate, accept_length = evaluate_posterior(logits, candidates, temperature, posterior_threshold, posterior_alpha)\n",
+    "            best_candidate, accept_length = evaluate_posterior(\n",
+    "                logits, candidates, temperature, posterior_threshold, posterior_alpha\n",
+    "            )\n",
     "        \n",
     "        with timed(wall_times, 'update'):\n",
-    "            input_ids, logits, medusa_logits, new_token = update_inference_inputs(input_ids, candidates, best_candidate, accept_length, medusa_buffers['retrieve_indices'], outputs, logits, medusa_logits, new_token, past_key_values_data, current_length_data)\n",
+    "            input_ids, logits, medusa_logits, new_token = update_inference_inputs(\n",
+    "                input_ids,\n",
+    "                candidates,\n",
+    "                best_candidate,\n",
+    "                accept_length,\n",
+    "                medusa_buffers[\"retrieve_indices\"],\n",
+    "                outputs,\n",
+    "                logits,\n",
+    "                medusa_logits,\n",
+    "                new_token,\n",
+    "                past_key_values_data,\n",
+    "                current_length_data,\n",
+    "            )\n",
     "\n",
     "        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():\n",
     "            break\n",
@@ -102,17 +158,15 @@
     "model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'\n",
     "model = MedusaModel.from_pretrained(\n",
     "    model_name,\n",
+    "    medusa_num_heads = 4,\n",
     "    torch_dtype=torch.float16,\n",
     "    low_cpu_mem_usage=True,\n",
     "    device_map=\"auto\"\n",
     ")\n",
     "tokenizer = model.get_tokenizer()\n",
     "\n",
-    "medusa_choices = torch.tensor([1, 7, 6])\n",
-    "num_heads = len(medusa_choices) - 1\n",
-    "medusa_topk = medusa_choices[1:]\n",
-    "\n",
-    "medusa_buffers = generate_medusa_buffers(medusa_choices, device=model.base_model.device)"
+    "medusa_choices = mc_sim_7b_63\n",
+    "\n"
    ]
   },
   {
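The hand-written configuration is gone: torch.tensor([1, 7, 6]) (read by the old cell as a leading entry for the base token plus per-head top-k values of 7 and 6) gives way to the named sparse tree mc_sim_7b_63, and buffer construction moves inside medusa_forward. For orientation, a hedged sketch of the buffers that call now builds lazily; the four keys are exactly the ones the decoding loop indexes, and each value is assumed to be a tensor:

# Assumes model and mc_sim_7b_63 are defined as in the cells above.
medusa_buffers = generate_medusa_buffers(mc_sim_7b_63, device=model.base_model.device)
for key in ("medusa_attn_mask", "tree_indices",
            "medusa_position_ids", "retrieve_indices"):
    print(key, tuple(medusa_buffers[key].shape))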
@@ -135,24 +189,6 @@
     "posterior_alpha = 0.3"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Initializing Past Values\n",
-    "\n",
-    "We initialize the dedicated cache for past key values."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model)"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -192,13 +228,10 @@
     "        torch.as_tensor(input_ids).cuda(),\n",
     "        model,\n",
     "        tokenizer,\n",
-    "        medusa_buffers,\n",
-    "        medusa_topk,\n",
+    "        medusa_choices,\n",
     "        temperature,\n",
     "        posterior_threshold,\n",
     "        posterior_alpha,\n",
-    "        past_key_values,\n",
-    "        past_key_values_data, current_length_data\n",
     "    )\n",
     "    output_ids = output_ids[0][len(input_ids[0]) :]\n",
     "    print(\"Output length:\", output_ids.size(-1))\n",
@@ -299,7 +332,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.16"
+   "version": "3.9.18"
   },
   "orig_nbformat": 4
  },
