Skip to content

Commit 60eeab9

Browse files
committed
Fix a CLI bug
1 parent 10a2697 commit 60eeab9

File tree

5 files changed

+110
-26
lines changed

5 files changed

+110
-26
lines changed

medusa/inference/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def main(args):
3737
model = MedusaModel.from_pretrained(
3838
args.model,
3939
args.base_model,
40+
medusa_num_heads = 4,
4041
torch_dtype=torch.float16,
4142
low_cpu_mem_usage=True,
4243
device_map="auto",

medusa/model/medusa_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ def get_tokenizer(self):
126126
def from_pretrained(
127127
cls,
128128
medusa_head_name_or_path,
129-
medusa_num_heads=None,
130129
base_model=None,
130+
medusa_num_heads=None,
131131
**kwargs,
132132
):
133133
"""

notebooks/medusa_configuration_explained.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
{
122122
"data": {
123123
"text/plain": [
124-
"<matplotlib.image.AxesImage at 0x7f79c4e723d0>"
124+
"<matplotlib.image.AxesImage at 0x7f1b60e80520>"
125125
]
126126
},
127127
"execution_count": 6,

notebooks/medusa_inference_explained.ipynb

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": null,
14+
"execution_count": 1,
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
@@ -42,7 +42,7 @@
4242
},
4343
{
4444
"cell_type": "code",
45-
"execution_count": null,
45+
"execution_count": 2,
4646
"metadata": {},
4747
"outputs": [],
4848
"source": [
@@ -151,9 +151,38 @@
151151
},
152152
{
153153
"cell_type": "code",
154-
"execution_count": null,
154+
"execution_count": 3,
155155
"metadata": {},
156-
"outputs": [],
156+
"outputs": [
157+
{
158+
"name": "stdout",
159+
"output_type": "stream",
160+
"text": [
161+
"Overriding medusa_num_heads as: 4\n"
162+
]
163+
},
164+
{
165+
"data": {
166+
"application/vnd.jupyter.widget-view+json": {
167+
"model_id": "ef69040c760f4e4b949e27b2c09526d2",
168+
"version_major": 2,
169+
"version_minor": 0
170+
},
171+
"text/plain": [
172+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
173+
]
174+
},
175+
"metadata": {},
176+
"output_type": "display_data"
177+
},
178+
{
179+
"name": "stderr",
180+
"output_type": "stream",
181+
"text": [
182+
"You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
183+
]
184+
}
185+
],
157186
"source": [
158187
"model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'\n",
159188
"model = MedusaModel.from_pretrained(\n",
@@ -180,7 +209,7 @@
180209
},
181210
{
182211
"cell_type": "code",
183-
"execution_count": null,
212+
"execution_count": 4,
184213
"metadata": {},
185214
"outputs": [],
186215
"source": [
@@ -200,7 +229,7 @@
200229
},
201230
{
202231
"cell_type": "code",
203-
"execution_count": null,
232+
"execution_count": 5,
204233
"metadata": {},
205234
"outputs": [],
206235
"source": [
@@ -218,9 +247,18 @@
218247
},
219248
{
220249
"cell_type": "code",
221-
"execution_count": null,
250+
"execution_count": 9,
222251
"metadata": {},
223-
"outputs": [],
252+
"outputs": [
253+
{
254+
"name": "stdout",
255+
"output_type": "stream",
256+
"text": [
257+
"Output length: 403\n",
258+
"Compression ratio: tensor(2.4724, device='cuda:0')\n"
259+
]
260+
}
261+
],
224262
"source": [
225263
"with torch.inference_mode():\n",
226264
" input_ids = tokenizer([prompt]).input_ids\n",
@@ -249,9 +287,27 @@
249287
},
250288
{
251289
"cell_type": "code",
252-
"execution_count": null,
290+
"execution_count": 10,
253291
"metadata": {},
254-
"outputs": [],
292+
"outputs": [
293+
{
294+
"name": "stdout",
295+
"output_type": "stream",
296+
"text": [
297+
"Once upon a time, in a small village nestled in the Andes mountains, there lived a charming llama named Luna. Luna was known for her kind heart and her love of coffee. She would often spend her afternoons sipping on a steaming cup of joe at the local café, chatting with the villagers and enjoying the warmth of the sun on her back.\n",
298+
"\n",
299+
"One day, as Luna was grazing on some fresh grass, she noticed that her hair was starting to grow longer and thicker. At first, she didn't think much of it, but as the days went on, her hair continued to grow and change. It became thick and wiry, with sharp spikes protruding from it.\n",
300+
"\n",
301+
"Luna was confused and a little scared by her new appearance. She had always been a gentle creature, and now she looked like a monster. She knew that she couldn't stay in the village anymore, so she set off on a journey to find a new home.\n",
302+
"\n",
303+
"As she wandered through the mountains, Luna stumbled upon a beautiful clearing. In the center of the clearing stood a small cottage, with a sign hanging outside that read \"Café Llama.\" Luna knew that this was where she belonged.\n",
304+
"\n",
305+
"She transformed the cottage into a cozy coffee shop, serving the best coffee in the mountains. The villagers were amazed by Luna's transformation, and they flocked to her café to taste her delicious brews.\n",
306+
"\n",
307+
"Luna's Medusa-like hair became her signature style, and she quickly became known as the most charming llama in the land. She spent her days sipping coffee, chatting with customers, and enjoying the warmth of the sun on her back. And she knew that she had finally found her true home.</s>\n"
308+
]
309+
}
310+
],
255311
"source": [
256312
"output = tokenizer.decode(\n",
257313
" output_ids,\n",
@@ -275,9 +331,30 @@
275331
},
276332
{
277333
"cell_type": "code",
278-
"execution_count": null,
334+
"execution_count": 11,
279335
"metadata": {},
280-
"outputs": [],
336+
"outputs": [
337+
{
338+
"name": "stdout",
339+
"output_type": "stream",
340+
"text": [
341+
"==================================================\n",
342+
"Wall time init: 0.026\n",
343+
"Wall time medusa: 0.031\n",
344+
"Wall time Tree: 3.732\n",
345+
"Wall time Posterior: 0.025\n",
346+
"Wall time Update: 0.051\n",
347+
"--------------------------------------------------\n",
348+
"Wall time portion medusa: 0.008\n",
349+
"Wall time portion Tree: 0.965\n",
350+
"Wall time portion Posterior: 0.007\n",
351+
"Wall time portion Update: 0.013\n",
352+
"--------------------------------------------------\n",
353+
"Tokens/second: 104.247\n",
354+
"==================================================\n"
355+
]
356+
}
357+
],
281358
"source": [
282359
"max_length = 50\n",
283360
"\n",
@@ -307,13 +384,6 @@
307384
"print(format_string(\"Tokens/second: \", new_token / time_total, max_length))\n",
308385
"print('='*max_length)"
309386
]
310-
},
311-
{
312-
"cell_type": "code",
313-
"execution_count": null,
314-
"metadata": {},
315-
"outputs": [],
316-
"source": []
317387
}
318388
],
319389
"metadata": {

notebooks/medusa_introduction.ipynb

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,17 @@
6666
"execution_count": 2,
6767
"metadata": {},
6868
"outputs": [
69+
{
70+
"name": "stdout",
71+
"output_type": "stream",
72+
"text": [
73+
"Overriding medusa_num_heads as: 4\n"
74+
]
75+
},
6976
{
7077
"data": {
7178
"application/vnd.jupyter.widget-view+json": {
72-
"model_id": "b6655fe35f99442e9c250f18882ce883",
79+
"model_id": "f4d7c1aaf692402f959920bc9ccd8593",
7380
"version_major": 2,
7481
"version_minor": 0
7582
},
@@ -234,14 +241,14 @@
234241
"name": "stdout",
235242
"output_type": "stream",
236243
"text": [
237-
"['a']\n"
244+
"['a']\n",
245+
"['time']\n"
238246
]
239247
},
240248
{
241249
"name": "stdout",
242250
"output_type": "stream",
243251
"text": [
244-
"['time']\n",
245252
"[',']\n",
246253
"['in']\n",
247254
"['a']\n",
@@ -917,7 +924,13 @@
917924
"name": "stdout",
918925
"output_type": "stream",
919926
"text": [
920-
"Prediction @ 1: ['Once']\n",
927+
"Prediction @ 1: ['Once']\n"
928+
]
929+
},
930+
{
931+
"name": "stdout",
932+
"output_type": "stream",
933+
"text": [
921934
"Prediction @ 2: ['upon', 'a']\n",
922935
"Prediction @ 3: ['time', ',', 'in']\n",
923936
"Prediction @ 4: ['a', 'small']\n",
@@ -1583,7 +1596,7 @@
15831596
},
15841597
{
15851598
"cell_type": "code",
1586-
"execution_count": 28,
1599+
"execution_count": 29,
15871600
"metadata": {},
15881601
"outputs": [
15891602
{

0 commit comments

Comments
 (0)