From a2b068186bba287c0295413e89700ebf4cc98f66 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 17 Sep 2025 20:42:22 +0000 Subject: [PATCH] update eagle example notebook Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/example.ipynb | 456 +++++++++++++------- 1 file changed, 302 insertions(+), 154 deletions(-) diff --git a/examples/speculative_decoding/example.ipynb b/examples/speculative_decoding/example.ipynb index 4278f0e2..e9a84a05 100644 --- a/examples/speculative_decoding/example.ipynb +++ b/examples/speculative_decoding/example.ipynb @@ -4,33 +4,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Synthesize data for speculative decoding training\n", - "\n", - "The speculative decoding medule needs to learn to predict tokens from the base model. Therefore, we need to prepare the data generated from the base model.\n", - "Note: if the target base model is a quantized version, the synthesized data should be generated using the quantized model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, quantize the base model (Llama-3.2-1B-Instruct) into FP8 and export to unified export format." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python llm_ptq/hf_ptq.py --pyt_ckpt_path meta-llama/Llama-3.2-1B-Instruct --qformat fp8 --batch_size 1 --export_path /tmp/llama3.2_1B_fp8 --export_fmt hf" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, download the Daring-Anteater dataset." + "## Prepare Data\n", + "In this example, we use the Daring-Anteater dataset. For improved accuracy, please refer to the [Data Synthesis Section](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding#optional-data-synthesis) in the README." ] }, { @@ -46,41 +21,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Then, launch an inference server that will run the quantized base model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!vllm serve /tmp/llama3.2_1B_fp8 --api-key token-abc123 --port 8000 --tensor-parallel-size 1 --quantization=modelopt" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Open a new terminal and adapt the fine-tuning data by calling this server.\n", - "Note: this may take a long time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!mkdir /tmp/finetune\n", - "!bash prepare_data.sh --data_path /tmp/Daring-Anteater/train.jsonl --output_path /tmp/finetune/data.jsonl --max_token 2048" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's load the base model and convert it to EAGLE Model" + "## Convert Model for Speculative Decoding\n", + "Here, we'll adapt our base model for speculative decoding by attaching a smaller EAGLE module. The upcoming code first loads meta-llama/Llama-3.2-1B as our base model and then configures the new draft module. To ensure compatibility, the draft module's dimensions must match the target model. Finally, the modelopt toolkit attaches this new, untrained module, leaving us with a combined model that is ready for the training phase later." 
] }, { @@ -93,28 +35,49 @@ "\n", "import modelopt.torch.opt as mto\n", "import modelopt.torch.speculative as mtsp\n", + "from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG\n", "\n", "mto.enable_huggingface_checkpointing()\n", "\n", + "# Load original HF model\n", + "base_model = \"meta-llama/Llama-3.2-1B\"\n", "model = transformers.AutoModelForCausalLM.from_pretrained(\n", - " \"meta-llama/Llama-3.2-1B-Instruct\", torch_dtype=\"auto\"\n", + " base_model, torch_dtype=\"auto\", device_map=\"cuda\"\n", ")\n", - "config = {\n", - " \"eagle_num_layers\": 1,\n", - " \"use_input_layernorm_in_first_layer\": True,\n", - " \"use_last_layernorm\": False,\n", - "}\n", + "\n", + "# Read default config for EAGLE3\n", + "config = EAGLE3_DEFAULT_CFG[\"config\"]\n", + "\n", + "# Hidden size and vocab size must match base model\n", + "config[\"eagle_architecture_config\"].update(\n", + " {\n", + " \"hidden_size\": model.config.hidden_size,\n", + " \"vocab_size\": model.config.vocab_size,\n", + " \"draft_vocab_size\": model.config.vocab_size,\n", + " \"max_position_embeddings\": model.config.max_position_embeddings,\n", + " }\n", + ")\n", + "\n", + "# Convert model for EAGLE speculative decoding\n", "mtsp.convert(model, [(\"eagle\", config)])\n", "\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n", - "tokenizer.pad_token_id = tokenizer.eos_token_id" + "# Prepare tokenizer\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id\n", + "if tokenizer.chat_template is None:\n", + " tokenizer.chat_template = (\n", + " \"{%- for message in messages %}\"\n", + " \"{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}\"\n", + " \"{%- endfor %}\"\n", + " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Once synthesized data is ready, we can start training the eagle model." + "## Train Draft Module on Daring-Anteater\n", + "We will fine-tune the draft module on the Daring-Anteater dataset using the standard Hugging Face Trainer. Note that only the draft module's weights are updated during this process; the original target model remains frozen. After training, our speculative decoding model will be ready for export and deployment. Keep in mind that training time depends heavily on the number of epochs (default: 4) and the hardware being used." ] }, { @@ -126,10 +89,10 @@ "import json\n", "from dataclasses import dataclass, field\n", "\n", - "from speculative_decoding.eagle_utils import DataCollatorWithPadding, LazySupervisedDataset\n", + "from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset\n", "from transformers import Trainer\n", "\n", - "with open(\"/tmp/finetune/data.jsonl\") as f:\n", + "with open(\"/tmp/Daring-Anteater/train.jsonl\") as f:\n", " data_json = [json.loads(line) for line in f]\n", "train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)\n", "eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)\n", "\n", "\n", "@dataclass\n", "class TrainingArguments(transformers.TrainingArguments):\n", - " cache_dir: str | None = field(default=None)\n", - " model_max_length: int = field(\n", - " default=4096,\n", - " metadata={\n", - " \"help\": (\n", - " \"Maximum sequence length. 
Sequences will be right padded (and possibly truncated).\"\n", - " )\n", - " },\n", - " )\n", " dataloader_drop_last: bool = field(default=True)\n", " bf16: bool = field(default=True)\n", "\n", "\n", "training_args = TrainingArguments(\n", " output_dir=\"/tmp/eagle_bf16\",\n", - " num_train_epochs=1.0,\n", + " num_train_epochs=4,\n", " per_device_train_batch_size=1,\n", " per_device_eval_batch_size=1,\n", ")\n", @@ -166,25 +120,47 @@ ")\n", "trainer._move_model_to_device(model, trainer.args.device)\n", "\n", - "# Manually enable this to return loss in eval\n", - "trainer.can_return_loss = True\n", "# Make sure label_smoother is None\n", "assert trainer.label_smoother is None, \"label_smoother is not supported in speculative decoding!\"\n", "\n", "trainer.train()\n", "trainer.save_state()\n", "trainer.save_model(training_args.output_dir)\n", - "tokenizer.save_pretrained(training_args.output_dir)\n", + "tokenizer.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ + "## Export Model Checkpoint\n", + "To deploy this model, we first need to export it to a unified Hugging Face checkpoint." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "from modelopt.torch.export import export_hf_checkpoint\n", "\n", - "metrics = trainer.evaluate()\n", - "print(f\"Evaluation results: \\n{metrics}\")" + "model.eval()\n", + "export_hf_checkpoint(\n", + " model,\n", + " export_dir=\"/tmp/hf_ckpt\",\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now we have a EAGLE model in BF16 format. Next, we quantize this model into FP8 (PTQ)." + "## Deploying on TensorRT-LLM\n", + "\n", + "Here we show an example of deploying on TRT-LLM with `trtllm-serve` and the [TRT-LLM container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release). See the [Deployment](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding#deployment) section for more info. \n", + "\n", + "First, we write the `trtllm-serve` command and the config file we need into the `/tmp` folder." 
] }, { @@ -193,42 +169,45 @@ "metadata": {}, "outputs": [], "source": [ - "import modelopt.torch.quantization as mtq\n", - "import modelopt.torch.utils.dataset_utils as dataset_utils\n", + "trtllm_serve_script = f\"\"\"trtllm-serve {base_model} \\\\\n", + " --host 0.0.0.0 \\\\\n", + " --port 8000 \\\\\n", + " --backend pytorch \\\\\n", + " --max_batch_size 32 \\\\\n", + " --max_num_tokens 8192 \\\\\n", + " --max_seq_len 8192 \\\\\n", + " --extra_llm_api_options /tmp/extra-llm-api-config.yml\n", + "\"\"\"\n", "\n", - "mto.enable_huggingface_checkpointing()\n", + "extra_llm_api_config = \"\"\"enable_attention_dp: false\n", + "disable_overlap_scheduler: true\n", + "enable_autotuner: false\n", "\n", - "model = transformers.AutoModelForCausalLM.from_pretrained(\"/tmp/eagle_bf16\")\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(\"/tmp/eagle_bf16\")\n", + "cuda_graph_config:\n", + " max_batch_size: 1\n", "\n", - "calib_dataloader = dataset_utils.get_dataset_dataloader(\n", - " dataset_name=\"cnn_dailymail\",\n", - " tokenizer=tokenizer,\n", - " batch_size=1,\n", - " num_samples=512,\n", - " device=model.device,\n", - " include_labels=False,\n", - ")\n", + "speculative_config:\n", + " decoding_type: Eagle\n", + " max_draft_len: 3\n", + " speculative_model_dir: /tmp/hf_ckpt\n", "\n", - "quant_cfg = getattr(mtq, \"FP8_DEFAULT_CFG\")\n", - "quant_cfg[\"quant_cfg\"][\"*output_quantizer\"] = {\n", - " \"num_bits\": (4, 3),\n", - " \"axis\": None,\n", - " \"enable\": True,\n", - "}\n", + "kv_cache_config:\n", + " enable_block_reuse: false\n", + "\"\"\"\n", "\n", - "calibrate_loop = dataset_utils.create_forward_loop(calib_dataloader, dataloader=calib_dataloader)\n", - "model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)\n", - "mtq.print_quant_summary(model)\n", + "# Dump the two scripts into /tmp\n", + "with open(\"/tmp/trtllm_serve.sh\", \"w\") as f:\n", + " f.write(trtllm_serve_script)\n", "\n", - "model.save_pretrained(\"/tmp/eagle_fp8_ptq\")" + "with open(\"/tmp/extra-llm-api-config.yml\", \"w\") as f:\n", + " f.write(extra_llm_api_config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To maintain the accuracy, we need to finetune the model (QAT)." 
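+ "If the TRT-LLM container image is not already cached locally, you may want to pull it ahead of time so the launch cell below does not appear to hang while downloading (this is the same image tag used in the next cell):\n", + "\n", + "```bash\n", + "docker pull nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2\n", + "```\n", + "\n",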
+ "Next, we start a TRT-LLM container in the background and run `trtllm-serve` inside it, using our exported checkpoint and the configuration scripts we just created:" ] }, { @@ -237,36 +216,62 @@ "metadata": {}, "outputs": [], "source": [ - "training_args.output_dir = \"/tmp/eagle_fp8_qat\"\n", - "trainer = Trainer(\n", - " model=model,\n", - " tokenizer=tokenizer,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " eval_dataset=eval_dataset,\n", - " data_collator=DataCollatorWithPadding(),\n", + "import subprocess\n", + "import threading\n", + "\n", + "# Generate a unique container name so we can stop/remove it later\n", + "container_name = \"trtllm_serve_spec\"\n", + "\n", + "docker_cmd = [\n", + " \"docker\",\n", + " \"run\",\n", + " \"--rm\",\n", + " \"--net\",\n", + " \"host\",\n", + " \"--shm-size=2g\",\n", + " \"--ulimit\",\n", + " \"memlock=-1\",\n", + " \"--ulimit\",\n", + " \"stack=67108864\",\n", + " \"--gpus\",\n", + " \"all\",\n", + " \"-v\",\n", + " \"/tmp:/tmp\",\n", + " \"--name\",\n", + " container_name,\n", + " \"nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2\",\n", + " \"bash\",\n", + " \"-c\",\n", + " \"bash /tmp/trtllm_serve.sh\",\n", + "]\n", + "\n", + "# print docker outputs\n", + "proc = subprocess.Popen(\n", + " docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1\n", ")\n", - "trainer._move_model_to_device(model, trainer.args.device)\n", "\n", - "# Manually enable this to return loss in eval\n", - "trainer.can_return_loss = True\n", - "# Make sure label_smoother is None\n", - "assert trainer.label_smoother is None, \"label_smoother is not supported in speculative decoding!\"\n", "\n", - "trainer.train()\n", - "trainer.save_state()\n", - "trainer.save_model(training_args.output_dir)\n", - "tokenizer.save_pretrained(training_args.output_dir)\n", + "def stream_output(pipe):\n", + " for line in iter(pipe.readline, \"\"):\n", + " print(line, end=\"\")\n", + "\n", + "\n", + "# Use thread to print outputs\n", + "thread = threading.Thread(target=stream_output, args=(proc.stdout,))\n", + "thread.daemon = True\n", + "thread.start()\n", "\n", - "metrics = trainer.evaluate()\n", - "print(f\"Evaluation results: \\n{metrics}\")" + "print(\n", + " f\"Starting trtllm-serve in Docker (PID: {proc.pid}, container name: {container_name}) in the background:\"\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To deploy this model, we need to first export it to a Unified checkpoint." + "Please wait for the service to fully start inside the container. 
\n", + "Once you see the message `INFO: Application startup complete.`, you can proceed to send requests to the service:" ] }, { @@ -275,24 +280,84 @@ "metadata": {}, "outputs": [], "source": [ - "from accelerate.hooks import remove_hook_from_module\n", + "import json\n", "\n", - "from modelopt.torch.export import export_hf_checkpoint\n", + "import requests\n", "\n", - "# Move meta tensor back to device before exporting.\n", - "remove_hook_from_module(model, recurse=True)\n", + "payload = {\n", + " \"model\": base_model,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n", + " ],\n", + " \"max_tokens\": 512,\n", + " \"temperature\": 0,\n", + " \"chat_template\": tokenizer.chat_template,\n", + "}\n", + "headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n", "\n", - "export_hf_checkpoint(\n", - " model,\n", - " export_dir=\"/tmp/hf_ckpt\",\n", - ")" + "response = requests.post(\n", + " \"http://localhost:8000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n", + ")\n", + "output = response.json()\n", + "\n", + "print(output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we clean up the container we created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!docker rm -f trtllm_serve_spec" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deploying on SGLang\n", + "Here, we deploy our trained model using SGLang. The following code defines the command needed to run the SGLang server with our specific configuration for speculative decoding." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# SGLang server launch command shell script\n", + "sglang_serve_script = f\"\"\"python3 -m sglang.launch_server \\\\\n", + " --model {base_model} \\\\\n", + " --host 0.0.0.0 \\\\\n", + " --port 30000 \\\\\n", + " --speculative-algorithm EAGLE3 \\\\\n", + " --speculative-eagle-topk 8 \\\\\n", + " --speculative-draft-model-path /tmp/hf_ckpt \\\\\n", + " --speculative-num-draft-tokens 3 \\\\\n", + " --speculative-num-steps 3 \\\\\n", + " --mem-fraction 0.6 \\\\\n", + " --cuda-graph-max-bs 2 \\\\\n", + " --dtype float16\n", + "\"\"\"\n", + "\n", + "with open(\"/tmp/sglang_serve.sh\", \"w\") as f:\n", + " f.write(sglang_serve_script)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then convert the Unified ckeckpoint to TRTLLM checkpoint." + "Launch the SGLang server inside a Docker container as a background process." ] }, { @@ -301,14 +366,67 @@ "metadata": {}, "outputs": [], "source": [ - "!python TensorRT-LLM/examples/eagle/convert_checkpoint.py --model_dir /tmp/hf_ckpt --output_dir /tmp/trtllm_ckpt --num_eagle_layers 5 --max_non_leaves_per_layer 4 --max_draft_len 25 --dtype float16" + "import os\n", + "import subprocess\n", + "import threading\n", + "\n", + "container_name = \"sglang_serve_spec\"\n", + "home_dir = os.path.expanduser(\"~\")\n", + "hf_cache_dir = os.path.join(home_dir, \".cache\", \"huggingface\")\n", + "\n", + "# Ensure the Hugging Face cache directory exists. 
This directory should already exist at ~/.cache/huggingface from when the model files for meta-llama/Llama-3.2-1B were downloaded earlier.\n", + "os.makedirs(hf_cache_dir, exist_ok=True)\n", + "\n", + "docker_cmd = [\n", + " \"docker\",\n", + " \"run\",\n", + " \"--rm\",\n", + " \"--net\",\n", + " \"host\",\n", + " \"--shm-size=32g\",\n", + " \"--gpus\",\n", + " \"all\",\n", + " \"-v\",\n", + " f\"{hf_cache_dir}:/root/.cache/huggingface\",\n", + " \"-v\",\n", + " \"/tmp:/tmp\",\n", + " \"--ipc=host\",\n", + " \"--name\",\n", + " container_name,\n", + " \"lmsysorg/sglang:latest\",\n", + " \"bash\",\n", + " \"-c\",\n", + " \"bash /tmp/sglang_serve.sh\",\n", + "]\n", + "\n", + "# Launch the Docker container\n", + "proc = subprocess.Popen(\n", + " docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1\n", + ")\n", + "\n", + "\n", + "# Stream the process output\n", + "def stream_output(pipe):\n", + " for line in iter(pipe.readline, \"\"):\n", + " print(line, end=\"\")\n", + "\n", + "\n", + "# Use a thread to stream the output without blocking the notebook\n", + "thread = threading.Thread(target=stream_output, args=(proc.stdout,))\n", + "thread.daemon = True\n", + "thread.start()\n", + "\n", + "print(\n", + " f\"Starting SGLang server in Docker (PID: {proc.pid}, container name: {container_name}) in the background:\"\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Last, build a TensorRT-LLM engine." + "As with TRT-LLM, please wait for the service to fully start inside the container. \n", + "Once you see the message `INFO: Application startup complete.`, you can proceed to send requests to the service:" ] }, { @@ -317,14 +435,35 @@ "metadata": {}, "outputs": [], "source": [ - "!trtllm-build --checkpoint_dir /tmp/trtllm_ckpt --output_dir /tmp/trtllm_engine --gemm_plugin float16 --use_paged_context_fmha enable --speculative_decoding_mode eagle --max_batch_size 4" + "import json\n", + "\n", + "import requests\n", + "\n", + "payload = {\n", + " \"model\": base_model,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n", + " ],\n", + " \"max_tokens\": 512,\n", + " \"temperature\": 0,\n", + "}\n", + "headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n", + "\n", + "# Send request to the SGLang server\n", + "response = requests.post(\n", + " \"http://localhost:30000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n", + ")\n", + "output = response.json()\n", + "\n", + "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To run the EAGLE engine, please refer to [TensorRT-LLM/examples/eagle](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/eagle):" + "Finally, we clean up the SGLang container we created." ] }, { @@ -333,18 +472,27 @@ "metadata": {}, "outputs": [], "source": [ - "!python ../run.py --engine_dir /tmp/trtllm_engine \\\n", - " --tokenizer_dir /tmp/eagle_fp8_qat \\\n", - " --max_output_len=100 \\\n", - " --eagle_choices=\"[[0],[1],[2],[3],[0,0],[0,1],[0,2],[1,0],[1,1],[2,0],[2,1],[3,0],[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,2,0],[0,2,1],[1,0,0],[0,0,0,0],[0,0,0,1],[0,0,0,2],[0,0,0,0,0],[0,0,0,0,1]]\" \\\n", - " --temperature 1.0 \\\n", - " --input_text \"Once upon\"" + "!docker rm -f sglang_serve_spec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ + "## Deploying on vLLM (Coming Soon)" ] }, { "cell_type": "markdown",
"metadata": {}, + "source": [ + "While vLLM is another extremely popular, high-performance inference server, direct support for speculative decoding with this demo notebook is still under active development. This notebook will be updated once deployment is possible." ] } ], "metadata": { "kernelspec": { - "display_name": "py312", + "display_name": "modelopt+vllm", "language": "python", "name": "python3" }, @@ -358,7 +506,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.0" } }, "nbformat": 4,