Skip to content

Commit 293d659

Browse files
committed
merge notebook updates from Jamie
Signed-off-by: h-guo18 <[email protected]>
1 parent e7a98f7 commit 293d659

File tree

1 file changed

+176
-7
lines changed

1 file changed

+176
-7
lines changed

examples/speculative_decoding/example.ipynb

Lines changed: 176 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
"cell_type": "markdown",
2222
"metadata": {},
2323
"source": [
24-
"## Convert Model\n",
25-
"Let's load the base model and convert it to EAGLE3 Model"
24+
"## Convert Model for Speculative Decoding\n",
25+
"Here, we'll adapt our base model for speculative decoding by attaching a smaller EAGLE Head. The upcoming code first loads meta-llama/Llama-3.2-1B as our base model and then configures the new draft head. To ensure compatibility, the draft head's dimensions must match the target model. Finally, the modelopt toolkit attaches this new, untrained head, leaving us with a combined model that is ready for the training phase later."
2626
]
2727
},
2828
{
@@ -76,8 +76,8 @@
7676
"cell_type": "markdown",
7777
"metadata": {},
7878
"source": [
79-
"## Train Draft Model On Daring-Anteater\n",
80-
"Then we can start training the eagle model with HF trainer."
79+
"## Train Draft Head On Daring-Anteater\n",
80+
"We will fine-tune the draft head on the Daring-Anteater dataset using the standard Hugging Face Trainer. Note that only the draft model's weights are updated during this process; the original target model remains unchanged. After training, our speculative decoding model will be ready for export and deployment. Note that the time to train will be significantly dependent on the epochs (default=4) and the hardware being used."
8181
]
8282
},
8383
{
@@ -106,7 +106,7 @@
106106
"\n",
107107
"training_args = TrainingArguments(\n",
108108
" output_dir=\"/tmp/eagle_bf16\",\n",
109-
" num_train_epochs=2,\n",
109+
" num_train_epochs=4,\n",
110110
" per_device_train_batch_size=1,\n",
111111
" per_device_eval_batch_size=1,\n",
112112
")\n",
@@ -156,7 +156,7 @@
156156
"cell_type": "markdown",
157157
"metadata": {},
158158
"source": [
159-
"## Deployment\n",
159+
"## Deploying on TensorRT-LLM\n",
160160
"\n",
161161
"Here we show an example to deploy on TRT-LLM with `trtllm-serve` and [TRT-LLM container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release). See [Deployment](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding#deployment) section for more info. \n",
162162
"\n",
@@ -288,7 +288,7 @@
288288
" \"model\": base_model,\n",
289289
" \"messages\": [\n",
290290
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
291-
" {\"role\": \"user\", \"content\": \"Hi, write me a story about a cat\"},\n",
291+
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n",
292292
" ],\n",
293293
" \"max_tokens\": 512,\n",
294294
" \"temperature\": 0,\n",
@@ -319,6 +319,175 @@
319319
"source": [
320320
"!docker rm -f trtllm_serve_spec"
321321
]
322+
},
323+
{
324+
"cell_type": "markdown",
325+
"metadata": {},
326+
"source": [
327+
"## Deploying on SGLang\n",
328+
"Here, we deploy our trained model using SGLang. The following code defines the command needed to run the SGLang server with our specific configuration for speculative decoding."
329+
]
330+
},
331+
{
332+
"cell_type": "code",
333+
"execution_count": null,
334+
"metadata": {},
335+
"outputs": [],
336+
"source": [
337+
"# SGLang server launch command shell script\n",
338+
"sglang_serve_script = f\"\"\"python3 -m sglang.launch_server \\\\\n",
339+
" --model {base_model} \\\\\n",
340+
" --host 0.0.0.0 \\\\\n",
341+
" --port 30000 \\\\\n",
342+
" --speculative-algorithm EAGLE3 \\\\\n",
343+
" --speculative-eagle-topk 8 \\\\\n",
344+
" --speculative-draft-model-path /tmp/hf_ckpt \\\\\n",
345+
" --speculative-num-draft-tokens 3 \\\\\n",
346+
" --speculative-num-steps 3 \\\\\n",
347+
" --mem-fraction 0.6 \\\\\n",
348+
" --cuda-graph-max-bs 2 \\\\\n",
349+
" --dtype float16\n",
350+
"\"\"\"\n",
351+
"\n",
352+
"with open(\"/tmp/sglang_serve.sh\", \"w\") as f:\n",
353+
" f.write(sglang_serve_script)"
354+
]
355+
},
356+
{
357+
"cell_type": "markdown",
358+
"metadata": {},
359+
"source": [
360+
"Launch the SGLang server inside a Docker container as a background process."
361+
]
362+
},
363+
{
364+
"cell_type": "code",
365+
"execution_count": null,
366+
"metadata": {},
367+
"outputs": [],
368+
"source": [
369+
"import os\n",
370+
"import subprocess\n",
371+
"import threading\n",
372+
"\n",
373+
"container_name = \"sglang_serve_spec\"\n",
374+
"home_dir = os.path.expanduser(\"~\")\n",
375+
"hf_cache_dir = os.path.join(home_dir, \".cache\", \"huggingface\")\n",
376+
"\n",
377+
"# Ensure the Hugging Face cache directory exists. This directory should exist as ~/.cache/huggingface, when the model files for meta-llama/Llama-3.2-1B were downloaded earlier.\n",
378+
"os.makedirs(hf_cache_dir, exist_ok=True)\n",
379+
"\n",
380+
"docker_cmd = [\n",
381+
" \"docker\",\n",
382+
" \"run\",\n",
383+
" \"--rm\",\n",
384+
" \"--net\",\n",
385+
" \"host\",\n",
386+
" \"--shm-size=32g\",\n",
387+
" \"--gpus\",\n",
388+
" \"all\",\n",
389+
" \"-v\",\n",
390+
" f\"{hf_cache_dir}:/root/.cache/huggingface\",\n",
391+
" \"-v\",\n",
392+
" \"/tmp:/tmp\",\n",
393+
" \"--ipc=host\",\n",
394+
" \"--name\",\n",
395+
" container_name,\n",
396+
" \"lmsysorg/sglang:latest\",\n",
397+
" \"bash\",\n",
398+
" \"-c\",\n",
399+
" \"bash /tmp/sglang_serve.sh\",\n",
400+
"]\n",
401+
"\n",
402+
"# Launch the Docker container\n",
403+
"proc = subprocess.Popen(\n",
404+
" docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1\n",
405+
")\n",
406+
"\n",
407+
"\n",
408+
"# Stream the process output\n",
409+
"def stream_output(pipe):\n",
410+
" for line in iter(pipe.readline, \"\"):\n",
411+
" print(line, end=\"\")\n",
412+
"\n",
413+
"\n",
414+
"# Use a thread to stream the output in without blocking the notebook\n",
415+
"thread = threading.Thread(target=stream_output, args=(proc.stdout,))\n",
416+
"thread.daemon = True\n",
417+
"thread.start()\n",
418+
"\n",
419+
"print(\n",
420+
" f\"Starting SGLang server in Docker (PID: {proc.pid}, container name: {container_name}) in the background:\"\n",
421+
")"
422+
]
423+
},
424+
{
425+
"cell_type": "markdown",
426+
"metadata": {},
427+
"source": [
428+
"As with TRT-LLM, please wait for the service to fully start inside the container. \n",
429+
"Once you see the message `INFO: Application startup complete.`, you can proceed to send requests to the service:"
430+
]
431+
},
432+
{
433+
"cell_type": "code",
434+
"execution_count": null,
435+
"metadata": {},
436+
"outputs": [],
437+
"source": [
438+
"import json\n",
439+
"\n",
440+
"import requests\n",
441+
"\n",
442+
"payload = {\n",
443+
" \"model\": base_model,\n",
444+
" \"messages\": [\n",
445+
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
446+
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n",
447+
" ],\n",
448+
" \"max_tokens\": 512,\n",
449+
" \"temperature\": 0,\n",
450+
"}\n",
451+
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n",
452+
"\n",
453+
"# Send request to the SGLang server\n",
454+
"response = requests.post(\n",
455+
" \"http://localhost:30000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n",
456+
")\n",
457+
"output = response.json()\n",
458+
"\n",
459+
"print(output)"
460+
]
461+
},
462+
{
463+
"cell_type": "markdown",
464+
"metadata": {},
465+
"source": [
466+
"Clean up the container"
467+
]
468+
},
469+
{
470+
"cell_type": "code",
471+
"execution_count": null,
472+
"metadata": {},
473+
"outputs": [],
474+
"source": [
475+
"!docker rm -f sglang_serve_spec"
476+
]
477+
},
478+
{
479+
"cell_type": "markdown",
480+
"metadata": {},
481+
"source": [
482+
"## Deploying on vLLM (Coming Soon)"
483+
]
484+
},
485+
{
486+
"cell_type": "markdown",
487+
"metadata": {},
488+
"source": [
489+
"While vLLM is another extremely popular, high-performance inference server, direct support for speculative decoding with this demo notebook is still under active development. This notebook will be updated once deployment is possible."
490+
]
322491
}
323492
],
324493
"metadata": {

0 commit comments

Comments
 (0)