
Conversation

h-guo18 (Contributor) commented on Sep 10, 2025

What does this PR do?

Type of change: new example

Overview: This PR updates the existing speculative decoding notebook example in preparation for a blog post. Main changes include:

  • Temporarily removed the PTQ example for export compatibility;
  • Added serving examples to the notebook using the TRT-LLM and SGLang Docker images.

Usage

See the instructions in example.ipynb.
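Once one of the serving containers from the notebook is running, the endpoint can be exercised with a quick request along these lines (a minimal sketch mirroring the notebook's TRT-LLM smoke test; the port, model name, and payload follow the notebook's defaults):

```python
import json

import requests

# Minimal smoke test against the locally served model (assumes the TRT-LLM
# container launched by the notebook is listening on port 8000).
payload = {
    "model": "meta-llama/Llama-3.2-1B",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me about speculative decoding."},
    ],
    "max_tokens": 512,
    "temperature": 0,
}
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload),
    timeout=60,
)
response.raise_for_status()
print(response.json())
```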

Testing

Tested by running the notebook with Llama-3.2-1B.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • End-to-end speculative-decoding workflow: model adaptation, tokenizer alignment, training a draft module, export to a unified HF checkpoint, and containerized deployment.
    • One-click Docker/serve scripts for TRT-LLM and SGLang with background start, log streaming, and a sample test request.
    • Placeholder note for upcoming vLLM support.
  • Refactor

    • Notebook reorganized into clear stages: data, training (HF Trainer-based), export, and deployment.
  • Chores

    • Updated notebook kernel display name and Python version.


copy-pr-bot bot commented Sep 10, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 10, 2025

Warning

Rate limit exceeded

@h-guo18 has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 8 minutes and 40 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between 7035565 and a2b0681.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)

Walkthrough

The notebook examples/speculative_decoding/example.ipynb was rewritten into an end-to-end speculative-decoding pipeline: load meta-llama/Llama-3.2-1B and adapt it with an EAGLE3 draft head, train the draft with the HF Trainer, export a unified HF checkpoint, and add container-based deployment scripts for TRT-LLM and SGLang. The cache_dir and model_max_length fields were removed from the notebook's TrainingArguments.
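In outline, the adapted notebook flow looks roughly like this (a condensed sketch, not verbatim notebook code; the modelopt import paths and the export_hf_checkpoint keyword are assumptions based on the reviewed cells):

```python
import transformers

import modelopt.torch.speculative as mtsp  # import path assumed
from modelopt.torch.export import export_hf_checkpoint  # import path assumed
from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG  # import path assumed

base_model = "meta-llama/Llama-3.2-1B"
model = transformers.AutoModelForCausalLM.from_pretrained(base_model, torch_dtype="auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

# Start from the default EAGLE3 config (synced to the base model's dims in the
# notebook) and attach the draft module to the base model.
config = EAGLE3_DEFAULT_CFG["config"]
mtsp.convert(model, [("eagle", config)])

# ... train the draft module with the HF Trainer (see the training sketch below) ...

# Export a unified HF checkpoint that the TRT-LLM / SGLang containers can serve.
model.eval()
export_hf_checkpoint(model, export_dir="/tmp/hf_ckpt")  # keyword name assumed
```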

Changes

| Cohort / File(s) | Summary of Changes |
| --- | --- |
| Speculative decoding pipeline overhaul<br>`examples/speculative_decoding/example.ipynb` | Replaced the prior data-synthesis/FP8 focus with a full speculative-decoding flow: load meta-llama/Llama-3.2-1B, configure EAGLE3 to match model dims (hidden_size, vocab/draft_vocab, max_position_embeddings), prepare the tokenizer (model_max_length=1024, pad_token_id=eos_token_id, chat_template fallback), adapt the model for EAGLE3. |
| Training flow & API change (notebook-local)<br>`examples/speculative_decoding/example.ipynb` | Added HF Trainer-based training of the draft EAGLE module (DataCollatorWithPadding, device placement, 4 epochs, label_smoother=None), dataset load/split from /tmp/Daring-Anteater/train.jsonl, save tokenizer/model. Removed cache_dir and model_max_length from the notebook's TrainingArguments dataclass. |
| Export step<br>`examples/speculative_decoding/example.ipynb` | Added export to a unified HF checkpoint at /tmp/hf_ckpt via export_hf_checkpoint. |
| Deployment scaffolding (TRT-LLM, SGLang)<br>`examples/speculative_decoding/example.ipynb` | Added generation of /tmp/trtllm_serve.sh (speculative_config, kv_cache_config), background container launch, log streaming, smoke-test POST; added /tmp/sglang_serve.sh and a Docker-based SGLang speculative server; noted vLLM as coming soon. |
| Notebook metadata & reorg<br>`examples/speculative_decoding/example.ipynb` | Kernel display name set to modelopt+vllm; Python version updated to 3.12.0; notebook reorganized into explicit sections for model adaptation, training, export, and deployment. |
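The tokenizer preparation summarized in the first row amounts to a few lines (sketch; FALLBACK_CHAT_TEMPLATE is a placeholder for the notebook's own fallback string):

```python
import transformers

base_model = "meta-llama/Llama-3.2-1B"
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

# Cap sequence length for draft training and reuse EOS as the padding token.
tokenizer.model_max_length = 1024
tokenizer.pad_token_id = tokenizer.eos_token_id

# Install a chat template only if the checkpoint ships without one.
if tokenizer.chat_template is None:
    tokenizer.chat_template = FALLBACK_CHAT_TEMPLATE  # placeholder constant
```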

Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  actor U as User
  participant NB as Notebook
  participant HF as HF Model/Tokenizer
  participant DS as Daring-Anteater Dataset
  participant TR as HF Trainer
  participant CK as HF Checkpoint (/tmp/hf_ckpt)

  U->>NB: Run notebook
  NB->>HF: Load Llama-3.2-1B + tokenizer
  NB->>HF: Adapt model for EAGLE3 (draft head, config sync)
  NB->>DS: Load and split dataset (train/eval)
  NB->>TR: Configure Trainer (4 epochs, no label_smoother)
  TR->>HF: Train draft head
  TR-->>NB: Save model/tokenizer
  NB->>CK: Export unified HF checkpoint
  note over NB,CK: Export replaces prior quantization flow
```
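A compressed view of the training-and-export stage shown above, continuing from the adaptation sketch (it reuses model and tokenizer from that sketch; the LazySupervisedDataset constructor and the save calls are assumptions about the helpers' interfaces):

```python
import json
from dataclasses import dataclass, field

import transformers
from transformers import Trainer

from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset

# Load the conversation data used to train the draft module.
with open("/tmp/Daring-Anteater/train.jsonl") as f:
    data_json = [json.loads(line) for line in f]

# Hold out a small slice for evaluation; the exact split size is an assumption.
train_records, eval_records = data_json[:-100], data_json[-100:]
train_dataset = LazySupervisedDataset(train_records, tokenizer=tokenizer)  # signature assumed
eval_dataset = LazySupervisedDataset(eval_records, tokenizer=tokenizer)    # signature assumed


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    dataloader_drop_last: bool = field(default=True)
    bf16: bool = field(default=True)


training_args = TrainingArguments(
    output_dir="/tmp/eagle_bf16",
    num_train_epochs=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
)

trainer = Trainer(
    model=model,  # the EAGLE3-adapted model from the previous stage
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorWithPadding(),
)
trainer.train()
trainer.save_model(training_args.output_dir)        # save call assumed
tokenizer.save_pretrained(training_args.output_dir)
```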
```mermaid
sequenceDiagram
  autonumber
  actor U as User
  participant NB as Notebook
  participant TK as /tmp/trtllm_serve.sh
  participant SK as /tmp/sglang_serve.sh
  participant T as TRT-LLM Server (Docker)
  participant S as SGLang Server (Docker)

  U->>NB: Execute deployment cells
  NB->>TK: Generate TRT-LLM script with speculative_config & kv_cache_config
  NB->>T: Launch container (background), stream logs
  NB->>T: Send test inference request
  NB->>SK: Generate SGLang serve script (speculative settings)
  NB->>S: Launch SGLang container
  note over T,S: vLLM: placeholder (coming soon)
```
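The TRT-LLM deployment stage in the second diagram boils down to writing a serve script and launching it inside the container in the background (sketch; flags and the image tag mirror the reviewed cells, the extra-llm-api-config contents are elided, and the review comments below suggest replacing host networking with an explicit port mapping):

```python
import subprocess

base_model = "meta-llama/Llama-3.2-1B"

# Serve script executed inside the container; speculative_config and
# kv_cache_config live in /tmp/extra-llm-api-config.yml (contents elided here).
trtllm_serve_script = f"""trtllm-serve {base_model} \\
    --host 0.0.0.0 \\
    --port 8000 \\
    --max_num_tokens 8192 \\
    --max_seq_len 8192 \\
    --extra_llm_api_options /tmp/extra-llm-api-config.yml
"""
with open("/tmp/trtllm_serve.sh", "w") as f:
    f.write(trtllm_serve_script)

# Launch the server container in the background and keep a handle for log
# streaming / cleanup. Host networking follows the notebook; see the review
# comments for the -p 8000:8000 alternative.
container_name = "trtllm_serve_spec"
docker_cmd = [
    "docker", "run", "--rm",
    "--net", "host",
    "--gpus", "all",
    "-v", "/tmp:/tmp",
    "--name", container_name,
    "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
    "bash", "-c", "bash /tmp/trtllm_serve.sh",
]
proc = subprocess.Popen(docker_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
print(f"Starting trtllm-serve in Docker (PID: {proc.pid}, container: {container_name})")
```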

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I twitch my whiskers, hop with cheer,
Drafts and tokens hum so near,
EAGLE snug on Llama's frame,
Checkpoints saved and scripts aflame.
Containers launch—little rabbit clap! 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title "update eagle example notebook" accurately indicates that the PR updates the EAGLE example notebook and aligns with the main changes (reworking the speculative-decoding example, training/export, and deployment). It is concise and relates to the changeset, though somewhat generic; it could be more specific about the substantive updates (speculative-decoding pipeline, training/export, TRT-LLM/SGLang serving). |
| Docstring Coverage | ✅ Passed | No functions found in the changes. Docstring coverage check skipped. |

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

h-guo18 force-pushed the haoguo/update-eagle-notebook-example branch 2 times, most recently from dfe0b26 to de07831 on September 10, 2025 00:32
h-guo18 self-assigned this on Sep 10, 2025

codecov bot commented Sep 10, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.82%. Comparing base (d94fc1b) to head (a2b0681).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
```diff
@@           Coverage Diff           @@
##             main     #314   +/-   ##
=======================================
  Coverage   73.82%   73.82%           
=======================================
  Files         172      172           
  Lines       17438    17438           
=======================================
+ Hits        12873    12874    +1     
+ Misses       4565     4564    -1     
```

☔ View full report in Codecov by Sentry.

h-guo18 force-pushed the haoguo/update-eagle-notebook-example branch from de07831 to 293d659 on September 16, 2025 01:31
h-guo18 marked this pull request as ready for review on September 16, 2025 01:39
h-guo18 requested a review from a team as a code owner on September 16, 2025 01:39
coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (8)
examples/speculative_decoding/example.ipynb (8)

17-18: Prefer snapshot_download over git clone for HF datasets (avoids git‑lfs issues).

git clone of HF datasets often fails without git‑lfs. Use huggingface_hub.snapshot_download for reliability.

-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id="nvidia/Daring-Anteater", repo_type="dataset", local_dir="/tmp/Daring-Anteater", local_dir_use_symlinks=False, ignore_patterns=["*.md","*.png",".gitattributes",".git/*"])

65-72: Chat template fallback may not match Llama‑3.x format.

Your fallback uses <|im_start|> which differs from Llama‑3.x templates. Keep only if truly None; otherwise you risk mismatched formatting. Consider loading the template from the HF repo/config instead of hardcoding.
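A minimal guard along these lines keeps the hardcoded string as a last resort only (sketch; LLAMA3_STYLE_FALLBACK_TEMPLATE is a hypothetical placeholder for a Llama-3.x-formatted template):

```python
# Only install a fallback when the checkpoint truly ships no chat template.
if tokenizer.chat_template is None:
    # Hypothetical placeholder; prefer a Llama-3.x-style template over the
    # <|im_start|> format currently hardcoded in the notebook.
    tokenizer.chat_template = LLAMA3_STYLE_FALLBACK_TEMPLATE
```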


101-112: Make bf16 selection hardware‑aware; add a seed.

Hard‑coding bf16=True fails on GPUs without BF16. Also set a seed for reproducibility.

-from dataclasses import dataclass, field
+from dataclasses import dataclass, field
+import torch
+transformers.set_seed(42)
@@
 class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+    dataloader_drop_last: bool = field(default=True)
+    bf16: bool = field(default=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()))

129-130: Save tokenizer with the exported checkpoint (used by servers).

You save the tokenizer to the training dir but not to the exported HF ckpt used by serving.

-tokenizer.save_pretrained(training_args.output_dir)
+tokenizer.save_pretrained(training_args.output_dir)
+tokenizer.save_pretrained("/tmp/hf_ckpt")  # keep alongside exported model for serving

172-180: Derive server max lengths from model config; bind to localhost.

Avoid hard‑coding 8192 and exposing on 0.0.0.0. Use model.config.max_position_embeddings and bind to 127.0.0.1 for local demos.

-context_len = 8192
+context_len = int(getattr(model.config, "max_position_embeddings", 8192))

-trtllm_serve_script = f"""trtllm-serve {base_model} \
-    --host 0.0.0.0 \
+trtllm_serve_script = f"""trtllm-serve {base_model} \
+    --host 127.0.0.1 \
     --port 8000 \
@@
-    --max_num_tokens 8192 \
-    --max_seq_len 8192 \
+    --max_num_tokens {context_len} \
+    --max_seq_len {context_len} \
     --extra_llm_api_options /tmp/extra-llm-api-config.yml
 """

283-305: Harden the client request (timeouts + errors).

Add a short timeout and raise for non‑200 responses to avoid hanging cells.

-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://127.0.0.1:8000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=30,
+)
+response.raise_for_status()
+output = response.json()

337-350: Bind SGLang server to localhost; consider pinning dtype to bf16 if trained as bf16.

Match serving dtype to training where possible, and avoid exposing 0.0.0.0 for local demos.

-sglang_serve_script = f"""python3 -m sglang.launch_server \
-    --model {base_model} \
-    --host 0.0.0.0 \
+sglang_serve_script = f"""python3 -m sglang.launch_server \
+    --model {base_model} \
+    --host 127.0.0.1 \
@@
-    --dtype float16
+    --dtype bfloat16
 """

If BF16 isn’t supported, fall back to float16.


438-460: Add timeout and error checks for SGLang request.

-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://127.0.0.1:30000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=30,
+)
+response.raise_for_status()
+output = response.json()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 293d659.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: build-docs
🔇 Additional comments (6)
examples/speculative_decoding/example.ipynb (6)

148-152: LGTM on export path.

Exporting to a unified HF checkpoint is correct and matches the serving steps.


159-164: LGTM on TRT‑LLM deployment narrative.

Clear, usable steps with config dumping before container launch.


327-349: LGTM on SGLang configuration options.

Flags align with EAGLE‑style speculative decoding; minor host/dtype notes addressed above.


189-193: Resolved — use "Eagle" (capital E)

Confirmed: trtllm-serve expects decoding_type = "Eagle"; the example already uses this value, no change required.


380-401: Replace host networking with port mapping; pin Docker image to lmsysorg/sglang:v0.4.9.post4.

Host networking is unsafe; v0.4.9.post4 is the stable tag that includes EAGLE‑3.

File: examples/speculative_decoding/example.ipynb (lines 380–401)

-    "--net",
-    "host",
+    "-p","30000:30000",
@@
-    "lmsysorg/sglang:latest",
+    "lmsysorg/sglang:v0.4.9.post4",

92-92: Import valid — eagle_utils defines the symbols used.

examples/speculative_decoding/eagle_utils.py defines LazySupervisedDataset (line 139) and DataCollatorWithPadding (line 207); the notebook import is valid.

Likely an incorrect or invalid review comment.

ChenhanYu (Collaborator) left a comment


LGTM;

coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (4)
examples/speculative_decoding/example.ipynb (4)

17-18: Don’t hardcode dataset path or rely on git clone; load via datasets and add a guard.

git-lfs may be missing and /tmp paths vary. Use datasets API and/or parameterize the path with an existence check.

Option A — datasets API (preferred):

-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from datasets import load_dataset
+ds = load_dataset("nvidia/Daring-Anteater", split="train")
+data_json = [x for x in ds]

Then replace file I/O:

-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
+# data_json already loaded above; keep downstream code as-is

Option B — parameterize + guard:

+from pathlib import Path
+DATASET_PATH = Path(os.getenv("EAGLE_DATASET_PATH", "/tmp/Daring-Anteater/train.jsonl"))
+if not DATASET_PATH.exists():
+    raise FileNotFoundError(
+        f"Dataset not found at {DATASET_PATH}. "
+        "Set EAGLE_DATASET_PATH or install `datasets` and use load_dataset('nvidia/Daring-Anteater')."
+    )
-with open("/tmp/Daring-Anteater/train.jsonl") as f:
+with open(DATASET_PATH, encoding="utf-8") as f:
     data_json = [json.loads(line) for line in f]

Also applies to: 95-99


44-46: Remove device_map and the private _move_model_to_device; let Trainer handle placement.

Specifying device_map with Trainer and calling a private Trainer method is brittle and can misplace params. Let Trainer place the model (or call model.to(...) once before Trainer).

Apply:

- model = transformers.AutoModelForCausalLM.from_pretrained(
-     base_model, torch_dtype="auto", device_map="cuda"
- )
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     base_model, torch_dtype="auto"
+ )

And:

- trainer._move_model_to_device(model, trainer.args.device)
+ # Trainer handles device placement.

Also applies to: 121-121


219-246: Avoid --net host; mount HF cache for gated models; pin/verify TRT‑LLM image.

Use explicit port mapping and mount ~/.cache/huggingface so downloads work inside the container. rc image may be outdated; prefer a stable tag.

-import subprocess
-import threading
+import subprocess
+import threading
+import os
@@
-container_name = "trtllm_serve_spec"
+container_name = "trtllm_serve_spec"
+home_dir = os.path.expanduser("~")
+hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")
+os.makedirs(hf_cache_dir, exist_ok=True)
@@
-    "--net",
-    "host",
+    "-p","8000:8000",
@@
     "-v",
     "/tmp:/tmp",
+    "-v",
+    f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e","HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
@@
-    "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
+    "nvcr.io/nvidia/tensorrt-llm/release:<stable-tag>",

To confirm the latest stable TRT‑LLM container and trtllm-serve availability:

What is the latest stable version of the NVIDIA TensorRT-LLM release container, and does it include `trtllm-serve` with speculative decoding/EAGLE support?

384-401: Drop --net host and pin SGLang image; map port explicitly.

Host networking is unnecessary here and reduces safety; also avoid the floating latest tag.

-    "--net",
-    "host",
+    "-p","30000:30000",
@@
-    "lmsysorg/sglang:latest",
+    "lmsysorg/sglang:<pinned-version>",
🧹 Nitpick comments (3)
examples/speculative_decoding/example.ipynb (3)

101-112: Gate bf16 by capability; fallback to fp16 for wider compatibility.

Hard‑enabling bf16 can break on many GPUs. Auto‑detect and set fp16 when bf16 is unsupported.

-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    dataloader_drop_last: bool = field(default=True)
+    bf16: bool = field(default=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()))
 
 training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
-    num_train_epochs=4,
+    num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    fp16=not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
 )

248-267: Optional: poll readiness instead of manual wait.

Poll the HTTP endpoint before sending requests to reduce flakiness.

 print(
     f"Starting trtllm-serve in Docker (PID: {proc.pid}, container name: {container_name}) in the background:"
 )
+import time, requests as _r
+for _ in range(120):
+    try:
+        if _r.get("http://localhost:8000/v1/models", timeout=2).status_code == 200:
+            break
+    except Exception:
+        time.sleep(1)

442-451: Include chat_template in SGLang request for consistency.

You pass chat_template to TRT‑LLM but not to SGLang; add it to ensure identical prompts.

 payload = {
     "model": base_model,
     "messages": [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Tell me about speculative decoding."},
     ],
     "max_tokens": 512,
     "temperature": 0,
+    "chat_template": tokenizer.chat_template,
 }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 293d659 and 83e06df.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
examples/speculative_decoding/example.ipynb (4)

61-63: Ensure only the EAGLE/draft module is trainable; freeze base model params.

If convert() doesn’t freeze base weights, Trainer will update the full model. Verify and freeze non‑EAGLE params.

Minimal guard (after convert):

 mtsp.convert(model, [("eagle", config)])
+for name, p in model.named_parameters():
+    p.requires_grad = ("eagle" in name)
+trainable = [n for n, p in model.named_parameters() if p.requires_grad]
+print(f"Trainable params: {len(trainable)} -> {trainable[:5]}...")

Please confirm trainable names are limited to the EAGLE module.

Also applies to: 113-120


64-72: Tokenizer prep LGTM.

Pad token tied to EOS and conditional chat_template initialization look good.


148-153: Export step LGTM.

eval() before export and using export_hf_checkpoint to /tmp/hf_ckpt are appropriate.


172-196: Config content looks reasonable for an initial EAGLE trial.

Speculative config, kv-cache, and conservative cuda-graph max bs are sensible defaults for a demo.

If you plan larger batches later, consider revisiting max_batch_size and cuda_graph_config.

h-guo18 force-pushed the haoguo/update-eagle-notebook-example branch from 83e06df to 7035565 on September 17, 2025 20:34

copy-pr-bot bot commented Sep 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

h-guo18 force-pushed the haoguo/update-eagle-notebook-example branch from 7035565 to a2b0681 on September 17, 2025 20:43
coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (6)
examples/speculative_decoding/example.ipynb (6)

95-99: Do not hardcode dataset path; add guard or parameterize.

Make path configurable and fail clearly if missing.

Apply:

-import json
-from dataclasses import dataclass, field
-
-from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset
-from transformers import Trainer
-
-with open("/tmp/Daring-Anteater/train.jsonl") as f:
+import json, os
+from pathlib import Path
+from dataclasses import dataclass, field
+from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset
+from transformers import Trainer
+
+data_path = Path(os.getenv("DA_TRAIN_PATH", "/tmp/Daring-Anteater/train.jsonl"))
+if not data_path.exists():
+    raise FileNotFoundError(f"Dataset not found at {data_path}. Set DA_TRAIN_PATH or run the download cell.")
+with open(data_path) as f:
     data_json = [json.loads(line) for line in f]

If adopting datasets.load_dataset above, use ds instead of opening a file.


113-121: Don’t call private Trainer methods.

Remove _move_model_to_device; let Trainer handle placement (after removing device_map).

Apply:

 trainer = Trainer(
     model=model,
     tokenizer=tokenizer,
     args=training_args,
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     data_collator=DataCollatorWithPadding(),
 )
-trainer._move_model_to_device(model, trainer.args.device)

45-46: Remove device_map when using Trainer (let Trainer place the model).

This conflicts with Trainer and led you to call a private API later.

Apply:

-model = transformers.AutoModelForCausalLM.from_pretrained(
-    base_model, torch_dtype="auto", device_map="cuda"
-)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+    base_model, torch_dtype="auto"
+)

219-246: Avoid --net=host; mount HF cache for gated models.

Use explicit -p and mount ~/.cache/huggingface so the container sees previously downloaded weights.

Apply:

-import subprocess
-import threading
-
-# Generate a unique container name so we can stop/remove it later
-container_name = "trtllm_serve_spec"
+import subprocess, threading, os, uuid
+container_name = f"trtllm_serve_spec-{uuid.uuid4().hex[:8]}"
+home_dir = os.path.expanduser("~")
+hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")
+os.makedirs(hf_cache_dir, exist_ok=True)
 
 docker_cmd = [
     "docker",
     "run",
     "--rm",
-    "--net",
-    "host",
+    "-p","8000:8000",
     "--shm-size=2g",
     "--ulimit",
     "memlock=-1",
     "--ulimit",
     "stack=67108864",
     "--gpus",
     "all",
     "-v",
     "/tmp:/tmp",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e","HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
     "--name",
     container_name,
     "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
     "bash",
     "-c",
     "bash /tmp/trtllm_serve.sh",
 ]

369-401: SGLang: avoid --net=host; publish port explicitly.

Mirror the TRT‑LLM fix and keep HF cache mount.

Apply:

-    "--net",
-    "host",
+    "-p","30000:30000",

61-62: Freeze base model; train only the draft (“eagle”) module as stated.

Currently all params appear trainable; freeze non‑EAGLE params to match the doc text and reduce memory/compute.

Apply:

 mtsp.convert(model, [("eagle", config)])
+
+# Freeze base model; train only EAGLE params
+for name, p in model.named_parameters():
+    p.requires_grad = ("eagle" in name)
+
+trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+total = sum(p.numel() for p in model.parameters())
+print(f"Trainable params: {trainable:,} / {total:,}")
🧹 Nitpick comments (9)
examples/speculative_decoding/example.ipynb (9)

17-18: Prefer datasets/snapshot_download over git clone to avoid LFS issues.

Use HF Datasets or huggingface_hub to reliably fetch JSONL (works without git-lfs).

Apply:

-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from datasets import load_dataset
+ds = load_dataset("nvidia/Daring-Anteater", split="train")
+# If you still want a local file: ds.to_json("/tmp/Daring-Anteater/train.jsonl", lines=True)

65-72: Don’t rely on passing chat_template to servers; pre-render on client.

Some servers ignore client-supplied templates. Render text with tokenizer.apply_chat_template and send a plain prompt.

Example:

prompt = tokenizer.apply_chat_template(payload["messages"], tokenize=False, add_generation_prompt=True)

Then call the server API with a prompt/completions endpoint accordingly.


101-112: bf16 may not be supported on all GPUs; add a safe fallback.

Guard bf16 or expose dtype via arg/env.

Apply:

-    bf16: bool = field(default=True)
+    bf16: bool = field(default=False)
+    fp16: bool = field(default=True)

Or set at construction:

use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
training_args = TrainingArguments(..., bf16=use_bf16, fp16=not use_bf16)

172-204: TRT‑LLM: set conservative limits; keep extra config small.

Minor: consider matching max_batch_size to expected interactive usage (e.g., 1–4) and avoid redundant max_num_tokens & max_seq_len both at 8192.


283-305: Add timeout and error handling to requests.

Prevent notebook hangs and surface server errors.

Apply:

-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=60,
+)
+response.raise_for_status()
+output = response.json()

320-321: Use the same container_name variable for cleanup.

The hardcoded name can drift if you randomize names as suggested above.

Apply:

-!docker rm -f trtllm_serve_spec
+import subprocess
+subprocess.run(["docker","rm","-f", container_name], check=False)

438-460: Add timeout and error handling to SGLang client call.

Same rationale as TRT‑LLM cell.

Apply:

-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://localhost:30000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=60,
+)
+response.raise_for_status()
+output = response.json()

475-476: Use variable container_name for SGLang cleanup.

Keep cleanup robust with dynamic names.

Apply:

-!docker rm -f sglang_serve_spec
+subprocess.run(["docker","rm","-f", container_name], check=False)

495-510: Kernel name nit: not vLLM‑only.

Display name “modelopt+vllm” may mislead since the notebook targets TRT‑LLM and SGLang; consider a neutral name.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 83e06df and 7035565.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
examples/speculative_decoding/example.ipynb (4)

129-130: LGTM: saving model and tokenizer.

Save paths are consistent with output_dir.


148-152: LGTM: export_hf_checkpoint usage.

Export step aligns with subsequent serving configs.


48-59: Avoid mutating the imported default EAGLE3 config.

You update a nested dict in-place; this can leak state across runs/imports.

Apply:

+import copy
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (deep copy to avoid global mutation)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
⛔ Skipped due to learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

242-242: Pin container tags and verify availability.

examples/speculative_decoding/example.ipynb (lines 242, 396): replace nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2 and lmsysorg/sglang:latest with explicit, stable tags and confirm those image tags exist before publishing.

TensorRT‑LLM v1.1.0rc2 is published as a release. (github.com)
SGLang docs reference lmsysorg/sglang:latest — avoid :latest and pin a concrete release/tag. (docs.sglang.ai)

I could not run the provided skopeo check in the sandbox (skopeo not installed); run the original skopeo commands locally (or docker pull the pinned tags) to verify image availability.

h-guo18 merged commit 461980e into main on Sep 17, 2025
22 checks passed
h-guo18 deleted the haoguo/update-eagle-notebook-example branch on September 17, 2025 21:17
yeyu-nvidia pushed a commit that referenced this pull request Sep 18, 2025
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Ye Yu <[email protected]>