
Conversation

@jamieliNVIDIA commented Sep 14, 2025

What does this PR do?

Type of change: documentation

Overview: Updated the example.ipynb for speculative decoding to deploy on SGLang and TRT-LLM.

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • EAGLE3-based speculative decoding workflow and checkpoint export for reuse.
    • TRT-LLM and SGLang deployment flows with Docker-based serving and HTTP testing.
  • Changes

    • Switched to a Meta Llama 3 base and Daring-Anteater dataset; training increased (1→4 epochs); evaluation step removed; model conversion and tokenizer setup updated.
  • Breaking Changes

    • TrainingArguments no longer includes cache_dir and model_max_length.
  • Documentation

    • Notebook narrative reorganized for end-to-end data prep, training, export, and deployment.
  • Chores

    • VSCode workspace git settings updated (commit signing/sign-off).

@jamieliNVIDIA requested a review from a team as a code owner September 14, 2025 22:56

copy-pr-bot bot commented Sep 14, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 14, 2025

Walkthrough

The notebook replaces an FP8/quantization-focused speculative decoding example with an EAGLE3-based workflow: it switches to meta-llama/Llama-3.2-1B, prepares data from Daring-Anteater, attaches an EAGLE draft head via mtsp.convert, updates training/export, and adds TRT-LLM and SGLang deployment steps.
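For orientation, here is a condensed sketch of that flow (the Trainer setup is elided and the arguments are illustrative, not copied verbatim from the diff):

import transformers

import modelopt.torch.speculative as mtsp
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG

# Load the base model (the notebook uses meta-llama/Llama-3.2-1B on a single GPU)
base_model = "meta-llama/Llama-3.2-1B"
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype="auto", device_map="cuda"
)

# Attach the EAGLE draft head using the default EAGLE3 config
config = EAGLE3_DEFAULT_CFG["config"]
mtsp.convert(model, [("eagle", config)])

# ... fine-tune the draft head on Daring-Anteater with transformers.Trainer ...

# Export a unified HF checkpoint that the TRT-LLM and SGLang deployment steps consume
export_hf_checkpoint(model, export_dir="/tmp/hf_ckpt")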

Changes

Cohort / File(s) Summary
Notebook (Speculative Decoding)
examples/speculative_decoding/example.ipynb
Reworked example to use an EAGLE3 workflow; replaced FP8/data-synthesis narrative with Data & Model Preparation, updated headings/metadata.
Model & Conversion
examples/speculative_decoding/example.ipynb
Base model changed to meta-llama/Llama-3.2-1B (loaded with device_map="cuda"); added EAGLE3_DEFAULT_CFG/eagle_architecture_config; call to mtsp.convert(model, [("eagle", config)]); tokenizer init adjusted (model_max_length=1024, chat_template fallback).
Data & Dataset Handling
examples/speculative_decoding/example.ipynb
Switched from synthetic dataset to Daring-Anteater; download via shell/git clone; training data path updated to /tmp/Daring-Anteater/train.jsonl.
Training & Args
examples/speculative_decoding/example.ipynb
Training epochs increased to 4; TrainingArguments subclass signature altered (removed cache_dir and model_max_length fields); data collator import path changed; evaluation removed; training then checkpoint export.
Checkpoint Export
examples/speculative_decoding/example.ipynb
Added export_hf_checkpoint step to save HF checkpoint to /tmp/hf_ckpt.
Deployment: TRT-LLM
examples/speculative_decoding/example.ipynb
Added TRT-LLM deployment flow: generates trtllm-serve script and extra_llm_api_config.yml, configures kv_cache/speculative options, launches Docker, streams logs, runs HTTP tests, and cleans up.
Deployment: SGLang
examples/speculative_decoding/example.ipynb
Added SGLang deployment flow: creates sglang_serve.sh, launches Docker container, runs chat completion test, and cleans up container.
Editor Settings
.vscode/settings.json
Added Git settings (git.enableCommitSigning: true, git.alwaysSignOff: true) and fixed JSON trailing syntax for python.analysis.extraPaths.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Notebook
  participant HF as HuggingFace Model
  participant MT as mtsp.convert
  participant DS as Dataset (/tmp/Daring-Anteater)
  participant Trainer
  participant FS as Filesystem

  User->>Notebook: execute cells
  Notebook->>HF: load base model & tokenizer (Llama-3.2-1B)
  Notebook->>MT: mtsp.convert(model, [("eagle", cfg)]) -> attach EAGLE draft head
  Notebook->>DS: load /tmp/Daring-Anteater/train.jsonl
  Notebook->>Trainer: train (num_train_epochs=4)
  Trainer-->>Notebook: trained weights
  Notebook->>FS: export_hf_checkpoint -> /tmp/hf_ckpt
sequenceDiagram
  autonumber
  actor Client
  participant TRT as TRT-LLM Server
  participant SG as SGLang Server
  participant DH as EAGLE Draft Head
  participant BH as Base Model

  rect rgba(200,230,255,0.18)
    note over TRT: TRT-LLM speculative flow
    Client->>TRT: Chat/completion request
    TRT->>DH: request draft tokens
    DH-->>TRT: draft + acceptance info
    TRT->>BH: verify/complete tokens
    BH-->>TRT: final tokens
    TRT-->>Client: response
  end

  rect rgba(220,255,220,0.18)
    note over SG: SGLang flow (similar)
    Client->>SG: Chat request
    SG->>DH: draft generation
    SG->>BH: verify/complete
    BH-->>SG: final tokens
    SG-->>Client: response
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I twitch my whiskers for EAGLE’s flight,
Draft tokens dance in morning light.
Four rounds of training, checkpoints snug,
Docker drums hum—deploy, give a hug.
TRT and SGLang cheer—hop, speculative bug! 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: The title "Update eagle notebook example with sglang" is concise and directly related to a clear, visible change in the changeset (the example.ipynb now includes SGLang deployment). It is not vague or off-topic and therefore does not fail the relevancy check. However, the PR also introduces TRT-LLM deployment and broader EAGLE3/model conversion and training updates that the title does not mention.
  • Docstring Coverage — ✅ Passed: No functions found in the changes; docstring coverage check skipped.



@coderabbitai bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/example.ipynb (1)

114-121: Only train the EAGLE head; freeze the target model.

As-is, Trainer will update the full model, which is likely unintended/expensive.

Apply before creating Trainer:

+# Freeze base/target model; train only the EAGLE head
+for name, p in model.named_parameters():
+    p.requires_grad = ("eagle" in name.lower())

If modelopt offers a helper (e.g., mtsp.freeze_base_model/enable_train("eagle")), prefer that.

🧹 Nitpick comments (8)
examples/speculative_decoding/example.ipynb (8)

7-8: Call out prerequisites (HF auth + git‑lfs) up front.

Daring‑Anteater via git and Llama‑3.2‑1B are likely gated; readers may need HF tokens and git‑lfs. Add a short “Prereqs” note here to reduce setup failures.
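As a concrete shape for such a note, a prerequisite cell could look roughly like this (treating HF_TOKEN as the auth mechanism is an assumption, not something the notebook currently does):

import os
import shutil

from huggingface_hub import login

# Gated repos (Llama-3.2, Daring-Anteater) need Hugging Face auth
if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])
else:
    print("Set HF_TOKEN or run `huggingface-cli login` before downloading gated models/datasets.")

# The git-clone path additionally needs git-lfs on the host
if shutil.which("git-lfs") is None:
    print("git-lfs not found; install it or use huggingface_hub.snapshot_download instead.")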


43-47: HF gated model note.

meta-llama/Llama-3.2-1B typically requires license acceptance and an HF token. Consider documenting HF auth/env var usage before this cell to prevent download errors.


62-63: Capture/verify convert() return semantics.

Some modelopt APIs return a new module; others mutate in place. Assign defensively or confirm it’s in‑place to avoid silently training the wrong module.

Apply:

-mtsp.convert(model, [("eagle", config)])
+_ret = mtsp.convert(model, [("eagle", config)])
+if _ret is not None:
+    model = _ret

108-113: Set a seed for reproducibility.

Add a deterministic seed to make the demo repeatable.

Apply before creating Trainer:

+transformers.set_seed(42)

318-320: Make cleanup idempotent.

Apply:

-%%sh
-docker rm -f trtllm_serve_spec
+%%sh
+docker rm -f trtllm_serve_spec || true

375-387: Propagate HF cache/token into SGLang container too (mirror TRT‑LLM).

You already mount the cache—optionally plumb the token as an env var for gated models.

Apply:

-    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    # "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",

459-461: Idempotent cleanup.

Apply:

-%%sh
-docker rm -f sglang_serve_spec
+%%sh
+docker rm -f sglang_serve_spec || true

480-480: Kernel name nit.

“modelopt+serve” is a clearer display_name since this notebook demos TRT‑LLM and SGLang, not vLLM (yet).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 9436a44.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)
🔇 Additional comments (5)
examples/speculative_decoding/example.ipynb (5)

147-154: LGTM: HF checkpoint export step is clear.


336-349: LGTM: SGLang launch script looks reasonable.

If multi‑GPU is common for your users, consider exposing tensor‑parallel flags in the script.


494-494: Python 3.12 support confirmed (as of 2025-09-14). PyTorch, Transformers, NVIDIA TensorRT‑LLM, and SGLang official docs/releases list Python 3.12 support; no changes required.


183-197: Speculative config — set decoding_type to "Eagle"; keep autotuner off; enable block reuse if supported.

  • decoding_type must be "Eagle" (capital E).
  • enable_autotuner is experimental — do not enable by default in production; use offline tuning and validate before enabling.
  • kv_cache_config.enable_block_reuse is recommended for production (improves TTFT) only if the model was built with paged‑context/FMHA support (validate model build and behavior before enabling).

File: examples/speculative_decoding/example.ipynb (lines 183–197)


93-93: Verify eagle_utils classes & pad-token masking.

Confirm examples/speculative_decoding/eagle_utils.py defines DataCollatorWithPadding (found ~line 207) and LazySupervisedDataset, and verify DataCollatorWithPadding sets labels for pad tokens to -100 (i.e., masks positions where input_ids == pad_token_id, including the case pad_token_id == eos_token_id).
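For reference, the masking being asked about amounts to something like the following (a minimal sketch, not the actual eagle_utils implementation):

import torch

def mask_pad_labels(input_ids: torch.Tensor, labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Positions set to -100 are ignored by the cross-entropy loss.
    labels = labels.clone()
    labels[input_ids == pad_token_id] = -100
    # Note: when pad_token_id == eos_token_id this also masks genuine EOS tokens,
    # which is exactly the edge case flagged above.
    return labels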

Comment on lines 17 to 19
"%%sh\n",
"git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater"
]

🛠️ Refactor suggestion

⚠️ Potential issue

Avoid git clone for datasets; use huggingface_hub to dodge git‑lfs pitfalls.

git‑lfs often isn’t installed in notebook environments; snapshot_download is more reliable and cache‑friendly.

Apply this diff in the cell:

-%%sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+snapshot_download(
+    repo_id="nvidia/Daring-Anteater",
+    repo_type="dataset",
+    local_dir="/tmp/Daring-Anteater",
+    local_dir_use_symlinks=False
+)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"%%sh\n",
"git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater"
]
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="nvidia/Daring-Anteater",
repo_type="dataset",
local_dir="/tmp/Daring-Anteater",
local_dir_use_symlinks=False
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 17 to 19, the
notebook currently runs a shell git clone which can fail due to missing git‑lfs;
replace the shell git clone with a Python-based download using
huggingface_hub.snapshot_download to fetch the dataset to /tmp (or a configured
cache dir). Import huggingface_hub, call
snapshot_download(repo_id="nvidia/Daring-Anteater",
cache_dir="/tmp/Daring-Anteater" or allow default cache), and use the returned
path in subsequent cells; ensure to handle authentication/token if required and
add a short note to pip-install huggingface_hub if not present.

Comment on lines +49 to +61
"# Read Default Config for EAGLE3\n",
"config = EAGLE3_DEFAULT_CFG[\"config\"]\n",
"\n",
"# Hidden size and vocab size must match base model\n",
"config[\"eagle_architecture_config\"].update(\n",
" {\n",
" \"hidden_size\": model.config.hidden_size,\n",
" \"vocab_size\": model.config.vocab_size,\n",
" \"draft_vocab_size\": model.config.vocab_size,\n",
" \"max_position_embeddings\": model.config.max_position_embeddings,\n",
" }\n",
")\n",
"\n",

⚠️ Potential issue

Do not mutate the global default EAGLE3 config.

Updating EAGLE3_DEFAULT_CFG["config"] in place can affect subsequent runs/calls. Deep‑copy before mutation.

Apply:

-from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+import copy
@@
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (copy to avoid global mutation)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"# Read Default Config for EAGLE3\n",
"config = EAGLE3_DEFAULT_CFG[\"config\"]\n",
"\n",
"# Hidden size and vocab size must match base model\n",
"config[\"eagle_architecture_config\"].update(\n",
" {\n",
" \"hidden_size\": model.config.hidden_size,\n",
" \"vocab_size\": model.config.vocab_size,\n",
" \"draft_vocab_size\": model.config.vocab_size,\n",
" \"max_position_embeddings\": model.config.max_position_embeddings,\n",
" }\n",
")\n",
"\n",
import copy
# Read Default Config for EAGLE3 (copy to avoid global mutation)
config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
# Hidden size and vocab size must match base model
config["eagle_architecture_config"].update(
{
"hidden_size": model.config.hidden_size,
"vocab_size": model.config.vocab_size,
"draft_vocab_size": model.config.vocab_size,
"max_position_embeddings": model.config.max_position_embeddings,
}
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 49 to 61, the code
updates EAGLE3_DEFAULT_CFG["config"] in place which mutates the global default;
instead, create a deep copy of EAGLE3_DEFAULT_CFG["config"] (or of
EAGLE3_DEFAULT_CFG) into a local variable and perform the update on that copy,
then use the copied config for subsequent initialization so the global default
remains unchanged.

Comment on lines +65 to +73
"# Prepare Tokenizer\n",
"tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)\n",
"tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"if tokenizer.chat_template is None:\n",
" tokenizer.chat_template = (\n",
" \"{%- for message in messages %}\"\n",
" \"{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}\"\n",
" \"{%- endfor %}\"\n",
" )"

🛠️ Refactor suggestion

Align tokenizer limits to model; set padding side.

1024 may truncate unnecessarily if the base model supports more. Also set padding_side explicitly to avoid surprises in collators/loss masks.

Apply:

-tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
+tokenizer = transformers.AutoTokenizer.from_pretrained(
+    base_model,
+    model_max_length=getattr(model.config, "max_position_embeddings", 2048),
+)
+tokenizer.padding_side = "right"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"# Prepare Tokenizer\n",
"tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)\n",
"tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"if tokenizer.chat_template is None:\n",
" tokenizer.chat_template = (\n",
" \"{%- for message in messages %}\"\n",
" \"{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}\"\n",
" \"{%- endfor %}\"\n",
" )"
# Prepare Tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
base_model,
model_max_length=getattr(model.config, "max_position_embeddings", 2048),
)
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id
if tokenizer.chat_template is None:
tokenizer.chat_template = (
"{%- for message in messages %}"
"{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
"{%- endfor %}"
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 65–73, replace the
hardcoded model_max_length=1024 with the model's actual limit and explicitly set
padding_side: load the model config (AutoConfig.from_pretrained(base_model)) and
use its max_position_embeddings (or tokenizer.model_max_length if already
provided) to set tokenizer.model_max_length, then set tokenizer.padding_side =
"right" (or "left" if your training expects left padding) alongside
tokenizer.pad_token_id and keep the chat_template logic unchanged.

Comment on lines +96 to 99
"with open(\"/tmp/Daring-Anteater/train.jsonl\") as f:\n",
" data_json = [json.loads(line) for line in f]\n",
"train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)\n",
"eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)\n",

🛠️ Refactor suggestion

Avoid loading the entire JSONL into memory.

Stream or use datasets to handle large files robustly.

Apply:

-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
-train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
-eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
+import json
+def stream_jsonl(path):
+    with open(path) as f:
+        for line in f:
+            yield json.loads(line)
+all_data = list(stream_jsonl("/tmp/Daring-Anteater/train.jsonl"))
+split = int(len(all_data) * 0.95)
+train_dataset = LazySupervisedDataset(all_data[:split], tokenizer=tokenizer)
+eval_dataset = LazySupervisedDataset(all_data[split:], tokenizer=tokenizer)

Or swap to datasets.load_dataset("json", data_files=...).

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 96 to 99, the
notebook currently reads the entire JSONL into memory via json.loads on each
line; replace this with a streaming approach or use the datasets library. Either
(a) iterate the file and yield/process lines lazily (e.g., create generators or
a LazySupervisedDataset that reads from the file path rather than a pre-built
list) so you never materialize data_json, or (b) call
datasets.load_dataset("json", data_files=path, split="train",
streaming=True/with proper train/validation splits) and pass the resulting
dataset (or its iterator) into LazySupervisedDataset to avoid loading the full
file into memory. Ensure subsequent slicing/splitting is done via
streaming-aware methods (e.g., dataset.train_test_split or manual streaming
partition) rather than list indexing.

Comment on lines 102 to 106
"@dataclass\n",
"class TrainingArguments(transformers.TrainingArguments):\n",
" cache_dir: str | None = field(default=None)\n",
" model_max_length: int = field(\n",
" default=4096,\n",
" metadata={\n",
" \"help\": (\n",
" \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n",
" )\n",
" },\n",
" )\n",
" dataloader_drop_last: bool = field(default=True)\n",
" bf16: bool = field(default=True)\n",
"\n",

🛠️ Refactor suggestion

Don’t subclass TrainingArguments for minor toggles.

Subclassing a dataclass here adds maintenance risk for little gain. Pass bf16 and dataloader_drop_last directly.

Apply:

-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+TrainingArguments = transformers.TrainingArguments

And below (Line 108):

 training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
     num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    bf16=True,
+    dataloader_drop_last=True,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"@dataclass\n",
"class TrainingArguments(transformers.TrainingArguments):\n",
" cache_dir: str | None = field(default=None)\n",
" model_max_length: int = field(\n",
" default=4096,\n",
" metadata={\n",
" \"help\": (\n",
" \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n",
" )\n",
" },\n",
" )\n",
" dataloader_drop_last: bool = field(default=True)\n",
" bf16: bool = field(default=True)\n",
"\n",
TrainingArguments = transformers.TrainingArguments
training_args = TrainingArguments(
output_dir="/tmp/eagle_bf16",
num_train_epochs=4,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
bf16=True,
dataloader_drop_last=True,
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 102 to 106 (and the
instantiation at line 108), a custom @dataclass subclass of
transformers.TrainingArguments is being used just to toggle dataloader_drop_last
and bf16; remove this subclass entirely and stop maintaining a custom dataclass.
Replace uses of the custom class with the standard
transformers.TrainingArguments and pass dataloader_drop_last=True and bf16=True
as keyword arguments when creating the TrainingArguments instance (adjust the
instantiation at line 108 accordingly).

Comment on lines +173 to +181
"trtllm_serve_script = f\"\"\"trtllm-serve {base_model} \\\\\n",
" --host 0.0.0.0 \\\\\n",
" --port 8000 \\\\\n",
" --backend pytorch \\\\\n",
" --max_batch_size 32 \\\\\n",
" --max_num_tokens 8192 \\\\\n",
" --max_seq_len 8192 \\\\\n",
" --extra_llm_api_options /tmp/extra-llm-api-config.yml\n",
"\"\"\"\n",

🛠️ Refactor suggestion

Bound server seq/token limits to model/tokenizer.

Hard‑coding 8192 may conflict with the base model’s limit. Use dynamic values.

Apply:

-trtllm_serve_script = f"""trtllm-serve {base_model} \\
+trtllm_serve_script = f"""trtllm-serve {base_model} \\
     --host 0.0.0.0 \\
     --port 8000 \\
     --backend pytorch \\
     --max_batch_size 32 \\
-    --max_num_tokens 8192 \\
-    --max_seq_len 8192 \\
+    --max_num_tokens {tokenizer.model_max_length} \\
+    --max_seq_len {tokenizer.model_max_length} \\
     --extra_llm_api_options /tmp/extra-llm-api-config.yml
 """
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"trtllm_serve_script = f\"\"\"trtllm-serve {base_model} \\\\\n",
" --host 0.0.0.0 \\\\\n",
" --port 8000 \\\\\n",
" --backend pytorch \\\\\n",
" --max_batch_size 32 \\\\\n",
" --max_num_tokens 8192 \\\\\n",
" --max_seq_len 8192 \\\\\n",
" --extra_llm_api_options /tmp/extra-llm-api-config.yml\n",
"\"\"\"\n",
"trtllm_serve_script = f\"\"\"trtllm-serve {base_model} \\\\\n",
" --host 0.0.0.0 \\\\\n",
" --port 8000 \\\\\n",
" --backend pytorch \\\\\n",
" --max_batch_size 32 \\\\\n",
" --max_num_tokens {tokenizer.model_max_length} \\\\\n",
" --max_seq_len {tokenizer.model_max_length} \\\\\n",
" --extra_llm_api_options /tmp/extra-llm-api-config.yml\n",
"\"\"\"\n",
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 173 to 181, the
script hardcodes --max_num_tokens and --max_seq_len to 8192 which can exceed the
base model/tokenizer limits; change it to compute safe limits at runtime by
loading the model or tokenizer (e.g., tokenizer.model_max_length or
model.config.max_position_embeddings), then set max_num_tokens and max_seq_len
to the smaller of that model limit and your desired target (with a sensible
fallback value if the config is missing). Ensure the trtllm_serve_script string
injects these computed integers rather than the literal 8192 so the server flags
respect the actual model/tokenizer capacity.

Comment on lines +220 to +247
"import subprocess\n",
"import threading\n",
"\n",
"# Generate a unique container name so we can stop/remove it later\n",
"container_name = \"trtllm_serve_spec\"\n",
"\n",
"docker_cmd = [\n",
" \"docker\",\n",
" \"run\",\n",
" \"--rm\",\n",
" \"--net\",\n",
" \"host\",\n",
" \"--shm-size=2g\",\n",
" \"--ulimit\",\n",
" \"memlock=-1\",\n",
" \"--ulimit\",\n",
" \"stack=67108864\",\n",
" \"--gpus\",\n",
" \"all\",\n",
" \"-v\",\n",
" \"/tmp:/tmp\",\n",
" \"--name\",\n",
" container_name,\n",
" \"nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2\",\n",
" \"bash\",\n",
" \"-c\",\n",
" \"bash /tmp/trtllm_serve.sh\",\n",
"]\n",

🛠️ Refactor suggestion

⚠️ Potential issue

Mount HF cache (+token) and increase shared memory for TRT‑LLM.

Without the HF cache/token in the container, model download can fail; 2g shm is often too small.

Apply:

-    "--shm-size=2g",
+    "--shm-size=32g",
@@
-    "-v",
-    "/tmp:/tmp",
+    "-v", "/tmp:/tmp",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", "HF_HUB_ENABLE_HF_TRANSFER=1",
+    # optionally pass a token if needed:
+    # "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",

Also consider pre‑removing existing containers:

-# Generate a unique container name so we can stop/remove it later
-container_name = "trtllm_serve_spec"
+# Use a deterministic name and ensure it is not left over from prior runs
+container_name = "trtllm_serve_spec"
+subprocess.call(["docker","rm","-f",container_name])

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 220–247, the Docker
run command lacks mounting the Hugging Face cache/token and uses too small
shared memory; update the docker_cmd to: add volume mounts mapping the host HF
cache and token into the container (e.g. host ~/.cache/huggingface -> container
/root/.cache/huggingface and host ~/.huggingface or token file -> container
/root/.huggingface or appropriate path) so model downloads and auth work inside
the container, increase --shm-size from "2g" to a larger value (e.g. "8g" or
"16g") to avoid OOM on TRT‑LLM, and add a pre-run step to remove any existing
container with the same name (docker rm -f <container_name>) before starting the
new container.

Comment on lines +282 to 303
"import json\n",
"import requests\n",
"\n",
"from modelopt.torch.export import export_hf_checkpoint\n",
"payload = {\n",
" \"model\": base_model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n",
" ],\n",
" \"max_tokens\": 512,\n",
" \"temperature\": 0,\n",
" \"chat_template\": tokenizer.chat_template,\n",
"}\n",
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n",
"\n",
"# Move meta tensor back to device before exporting.\n",
"remove_hook_from_module(model, recurse=True)\n",
"response = requests.post(\n",
" \"http://localhost:8000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n",
")\n",
"output = response.json()\n",
"\n",
"export_hf_checkpoint(\n",
" model,\n",
" export_dir=\"/tmp/hf_ckpt\",\n",
")"
"print(output)"
]

🛠️ Refactor suggestion

Harden the request and drop non‑standard fields.

OpenAI Chat Completions may not accept chat_template; add error handling.

Apply:

-    "chat_template": tokenizer.chat_template,
 }
 headers = {"Content-Type": "application/json", "Accept": "application/json"}
 
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"import json\n",
"import requests\n",
"\n",
"from modelopt.torch.export import export_hf_checkpoint\n",
"payload = {\n",
" \"model\": base_model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n",
" ],\n",
" \"max_tokens\": 512,\n",
" \"temperature\": 0,\n",
" \"chat_template\": tokenizer.chat_template,\n",
"}\n",
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n",
"\n",
"# Move meta tensor back to device before exporting.\n",
"remove_hook_from_module(model, recurse=True)\n",
"response = requests.post(\n",
" \"http://localhost:8000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n",
")\n",
"output = response.json()\n",
"\n",
"export_hf_checkpoint(\n",
" model,\n",
" export_dir=\"/tmp/hf_ckpt\",\n",
")"
"print(output)"
]
import json
import requests
payload = {
"model": base_model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me about speculative decoding."},
],
"max_tokens": 512,
"temperature": 0,
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}
resp = requests.post(
"http://localhost:8000/v1/chat/completions",
headers=headers,
data=json.dumps(payload),
timeout=60,
)
resp.raise_for_status()
print(resp.json())
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 282 to 303, the POST
payload includes a non-standard "chat_template" field and lacks proper error
handling; remove "chat_template" from the payload, send JSON using requests'
json= parameter (not data=json.dumps), check response.status_code and raise or
print a helpful error when it's not 200, wrap the request in a try/except to
catch requests.exceptions (timeout/connection errors) and set a reasonable
timeout, and ensure headers remain Content-Type: application/json and Accept:
application/json.

Comment on lines 423 to 444
"import json\n",
"import requests\n",
"\n",
"payload = {\n",
" \"model\": base_model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n",
" ],\n",
" \"max_tokens\": 512,\n",
" \"temperature\": 0,\n",
"}\n",
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n",
"\n",
"#Send request to the SGLang server\n",
"response = requests.post(\n",
" \"http://localhost:30000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n",
")\n",
"output = response.json()\n",
"\n",
"print(output['choices'][0]['message']['content'])"
]

🛠️ Refactor suggestion

Add error handling on the SGLang request.

Apply:

-response = requests.post(
+response = requests.post(
     "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
 )
-output = response.json()
-
-print(output['choices'][0]['message']['content'])
+response.raise_for_status()
+output = response.json()
+print(output.get("choices", [{}])[0].get("message", {}).get("content", output))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"import json\n",
"import requests\n",
"\n",
"payload = {\n",
" \"model\": base_model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n",
" ],\n",
" \"max_tokens\": 512,\n",
" \"temperature\": 0,\n",
"}\n",
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n",
"\n",
"#Send request to the SGLang server\n",
"response = requests.post(\n",
" \"http://localhost:30000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n",
")\n",
"output = response.json()\n",
"\n",
"print(output['choices'][0]['message']['content'])"
]
import json
import requests
payload = {
"model": base_model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me about speculative decoding."},
],
"max_tokens": 512,
"temperature": 0,
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}
#Send request to the SGLang server
response = requests.post(
"http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
)
response.raise_for_status()
output = response.json()
print(output.get("choices", [{}])[0].get("message", {}).get("content", output))
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 423 to 444, the code
sends the SGLang request without any error handling; wrap the requests.post call
in a try/except that catches requests.exceptions.RequestException (use a
reasonable timeout), check response.raise_for_status() or response.ok and
log/raise a descriptive error including response.status_code and response.text
if the call failed, then safely parse JSON inside a try/except for
ValueError/JSONDecodeError and handle missing keys before printing (provide a
clear fallback or error message).


@coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (10)
examples/speculative_decoding/example.ipynb (10)

17-19: Replace git clone with huggingface_hub to avoid git‑lfs pitfalls.

Use snapshot_download for reliability and caching.

-%%sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+dataset_dir = "/tmp/Daring-Anteater"
+snapshot_download(
+    repo_id="nvidia/Daring-Anteater",
+    repo_type="dataset",
+    local_dir=dataset_dir,
+    local_dir_use_symlinks=False,
+)

66-73: Align tokenizer limits to model; set padding side; avoid ChatML fallback.

Hardcoding 1024 may truncate; ChatML fallback can mismatch Llama‑3.2.

-tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
-tokenizer.pad_token_id = tokenizer.eos_token_id
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = (
-        "{%- for message in messages %}"
-        "{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
-        "{%- endfor %}"
-    )
+model_max_len = getattr(getattr(model, "config", None), "max_position_embeddings", None) or 2048
+tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=model_max_len)
+tokenizer.padding_side = "right"
+tokenizer.pad_token_id = tokenizer.eos_token_id
+# Avoid overriding tokenizer.chat_template with a generic template.

96-99: Stream dataset or use datasets library; avoid loading entire JSONL into memory.

-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
-train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
-eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
+from datasets import load_dataset
+ds = load_dataset("json", data_files={"train": "/tmp/Daring-Anteater/train.jsonl"})["train"]
+splits = ds.train_test_split(test_size=0.05, seed=42)
+train_dataset = LazySupervisedDataset(splits["train"], tokenizer=tokenizer)
+eval_dataset = LazySupervisedDataset(splits["test"], tokenizer=tokenizer)

102-113: Don’t subclass TrainingArguments for minor toggles; pass flags directly.

-from dataclasses import dataclass, field
-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
-
-
-training_args = TrainingArguments(
+TrainingArguments = transformers.TrainingArguments
+training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
     num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    bf16=True,
+    dataloader_drop_last=True,
 )

122-122: Avoid private API Trainer._move_model_to_device().

Let Trainer handle placement or call model.to(...) explicitly.

-trainer._move_model_to_device(model, trainer.args.device)
+# Let Trainer manage device placement.

173-181: Bound max tokens/seq to tokenizer/model; don’t hardcode 8192.

-    --max_num_tokens 8192 \
-    --max_seq_len 8192 \
+    --max_num_tokens {tokenizer.model_max_length} \
+    --max_seq_len {tokenizer.model_max_length} \

220-247: Harden TRT‑LLM docker run: shm, HF cache/token, cleanup.

-# Generate a unique container name so we can stop/remove it later
-container_name = "trtllm_serve_spec"
+container_name = "trtllm_serve_spec"
+subprocess.call(["docker","rm","-f",container_name])
@@
-    "--shm-size=2g",
+    "--shm-size=32g",
@@
-    "-v",
-    "/tmp:/tmp",
+    "-v", "/tmp:/tmp",
+    "-v", f"{os.path.expanduser('~')}/.cache/huggingface:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
What is the latest stable NGC tag for nvcr.io/nvidia/tensorrt-llm/release and does trtllm-serve accept --backend pytorch with speculative_config?

282-303: Remove non‑standard chat_template and add robust HTTP error handling.

-    "temperature": 0,
-    "chat_template": tokenizer.chat_template,
+    "temperature": 0,
@@
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())

423-444: Add error handling to SGLang request.

-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:30000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json().get("choices", [{}])[0].get("message", {}).get("content"))

39-60: Do not mutate EAGLE3_DEFAULT_CFG in place; deep‑copy before updates.

Global mutation risks bleed‑through across runs/imports.

Note: I’m aware of the prior learning that ALGO_TO_CONFIG yields deep copies in tests; here we’re using EAGLE3_DEFAULT_CFG directly, so copy locally to be safe.

-from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+import copy
@@
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (local copy)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
🧹 Nitpick comments (2)
examples/speculative_decoding/example.ipynb (2)

375-387: Pass HF token into SGLang container and pre‑remove any stale container.

-container_name = "sglang_serve_spec"
+container_name = "sglang_serve_spec"
+subprocess.call(["docker","rm","-f",container_name])
@@
-    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",

93-94: Import OK — eagle_utils present at examples/speculative_decoding/eagle_utils.py

Defines LazySupervisedDataset and DataCollatorWithPadding, so "from eagle_utils import ..." works when example.ipynb is run from that directory. If the notebook will be executed from other CWDs or the module reused elsewhere, either add examples/speculative_decoding to PYTHONPATH, convert the folder to a package and use a relative import, or install the module.
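If the notebook is ever run from a different working directory, the PYTHONPATH workaround mentioned here is simply (the repo-root-relative path is an assumption):

import sys

sys.path.insert(0, "examples/speculative_decoding")
from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset  # noqa: E402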

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9436a44 and f46092b.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/example.ipynb
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (5)
examples/speculative_decoding/example.ipynb (5)

147-153: LGTM: HF Unified checkpoint export is correctly invoked.


318-320: LGTM: Cleanup step present to remove TRT‑LLM container.


44-47: Access to meta‑llama models often requires an HF token.

If the base model isn’t in the mounted HF cache, from_pretrained will fail in containers. Ensure HF_TOKEN is set and the cache is mounted (see docker comments).


336-345: Confirmed: SGLang EAGLE3 CLI flags are correct.
--speculative-algorithm EAGLE3, --speculative-draft-model-path, and --speculative-num-draft-tokens are valid/current for sglang.launch_server; optional tuning flags include --speculative-num-steps and --speculative-eagle-topk.
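For context, the corresponding sglang_serve.sh generation has roughly this shape (written as a Python f-string like the notebook's trtllm-serve script; the model flag, draft-token count, and draft-model path are assumptions, not values taken from the diff):

# Hypothetical sketch of the SGLang serve script; adjust flags/paths to the actual notebook.
base_model = "meta-llama/Llama-3.2-1B"
sglang_serve_script = f"""python3 -m sglang.launch_server \\
    --model-path {base_model} \\
    --speculative-algorithm EAGLE3 \\
    --speculative-draft-model-path /tmp/hf_ckpt \\
    --speculative-num-draft-tokens 4 \\
    --host 0.0.0.0 \\
    --port 30000
"""
with open("/tmp/sglang_serve.sh", "w") as f:
    f.write(sglang_serve_script)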


114-121: Incorrect — EAGLE conversion/config already handles freezing the base model.

The conversion uses eagle_freeze_base_model (default True) and passes it into the converted model which sets base params' requires_grad=False; the EAGLE head lives under model.eagle_module (e.g. eagle_module.eagle_lm_head, eagle_module.fc, eagle_module.decoder.layers.*). See modelopt/torch/speculative/config.py, modelopt/torch/speculative/eagle/conversion.py, modelopt/torch/speculative/plugins/megatron_eagle.py and modelopt/torch/speculative/plugins/transformers.py for the freeze logic.

Likely an incorrect or invalid review comment.

Signed-off-by: h-guo18 <[email protected]>

Updated example.ipynb with sglang steps

Updated haoguo's demo notebook with sglang instructions

fixed --model-path to --model

Commit with sign.

Signed-off-by: Jamie Li <[email protected]>
@jamieliNVIDIA force-pushed the update-eagle-notebook-example-with-sglang branch from f6b85a9 to f85794e on September 15, 2025 19:30

codecov bot commented Sep 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.88%. Comparing base (669ae05) to head (f85794e).
⚠️ Report is 7 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #316   +/-   ##
=======================================
  Coverage   73.87%   73.88%           
=======================================
  Files         172      172           
  Lines       17415    17444   +29     
=======================================
+ Hits        12865    12888   +23     
- Misses       4550     4556    +6     

☔ View full report in Codecov by Sentry.


@coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (9)
examples/speculative_decoding/example.ipynb (9)

17-18: Avoid git‑LFS headaches: use huggingface_hub instead of git clone.

Replace shell clone with snapshot_download for reliability and caching.

-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+snapshot_download(
+    repo_id="nvidia/Daring-Anteater",
+    repo_type="dataset",
+    local_dir="/tmp/Daring-Anteater",
+    local_dir_use_symlinks=False,
+)

64-72: Align tokenizer limits to the model and set padding side explicitly.

-# Prepare Tokenizer
-tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
-tokenizer.pad_token_id = tokenizer.eos_token_id
+# Prepare Tokenizer
+cfg = transformers.AutoConfig.from_pretrained(base_model)
+max_len = int(getattr(cfg, "max_position_embeddings", 2048))
+tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=max_len)
+tokenizer.pad_token_id = tokenizer.eos_token_id
+tokenizer.padding_side = "right"
 if tokenizer.chat_template is None:

95-98: Avoid loading the entire JSONL into memory; stream or use datasets.

-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
-train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
-eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
+def stream_jsonl(path):
+    with open(path) as f:
+        for line in f:
+            yield json.loads(line)
+all_data = list(stream_jsonl("/tmp/Daring-Anteater/train.jsonl"))
+split = int(len(all_data) * 0.95)
+train_dataset = LazySupervisedDataset(all_data[:split], tokenizer=tokenizer)
+eval_dataset = LazySupervisedDataset(all_data[split:], tokenizer=tokenizer)

(Or switch to datasets.load_dataset("json") with train_test_split.)


101-112: Don’t subclass TrainingArguments for minor toggles.

-from dataclasses import dataclass, field
+from dataclasses import dataclass, field  # (no longer needed if not subclassing)
@@
-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+TrainingArguments = transformers.TrainingArguments
@@
 training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
     num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    bf16=True,
+    dataloader_drop_last=True,
 )

121-121: Remove private API call Trainer._move_model_to_device().

This is unstable across HF releases; let Trainer handle device placement.

-trainer._move_model_to_device(model, trainer.args.device)
+# Let Trainer manage device placement

172-180: Bind TRT‑LLM token/sequence limits to the model.

 trtllm_serve_script = f"""trtllm-serve {base_model} \
@@
-    --max_num_tokens 8192 \
-    --max_seq_len 8192 \
+    --max_num_tokens {tokenizer.model_max_length} \
+    --max_seq_len {tokenizer.model_max_length} \

219-246: Container ergonomics: mount HF cache/token, bump shm, and pre‑remove old container.

 import subprocess
 import threading
+import os
@@
-# Generate a unique container name so we can stop/remove it later
+# Generate a unique container name so we can stop/remove it later
 container_name = "trtllm_serve_spec"
+home_dir = os.path.expanduser("~")
+hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")
+os.makedirs(hf_cache_dir, exist_ok=True)
+subprocess.call(["docker", "rm", "-f", container_name])

 docker_cmd = [
     "docker",
     "run",
     "--rm",
     "--net",
     "host",
-    "--shm-size=2g",
+    "--shm-size=32g",
     "--ulimit",
     "memlock=-1",
     "--ulimit",
     "stack=67108864",
+    "--ipc=host",
     "--gpus",
     "all",
-    "-v",
-    "/tmp:/tmp",
+    "-v", "/tmp:/tmp",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", "HF_HUB_ENABLE_HF_TRANSFER=1",
+    # optionally: "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
     "--name",
     container_name,
     "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
     "bash",
     "-c",
     "bash /tmp/trtllm_serve.sh",
 ]

281-301: Harden the TRT‑LLM request; remove non‑standard fields.

 payload = {
     "model": base_model,
@@
-    "temperature": 0,
-    "chat_template": tokenizer.chat_template,
+    "temperature": 0,
 }
 headers = {"Content-Type": "application/json", "Accept": "application/json"}
 
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())

421-442: Add basic error handling on the SGLang request.

-#Send request to the SGLang server
-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+# Send request to the SGLang server
+resp = requests.post(
+    "http://localhost:30000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())
🧹 Nitpick comments (5)
examples/speculative_decoding/example.ipynb (5)

34-39: Don’t mutate imported defaults; work on a copy of EAGLE3 config.

Mutating EAGLE3_DEFAULT_CFG in place can leak across cells/runs. Use a local deepcopy.

Using your retrieved learning: tests may not need extra deepcopy, but in a long‑lived notebook kernel avoiding global mutation is safer.

 import transformers
+import copy
 
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
@@
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (local copy to avoid global mutation)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])

Also applies to: 48-59


43-46: Note: base model may require an HF token.

Llama 3.2 repos often need auth. Add a short note or check for HF_TOKEN/env login to reduce friction.

 # Load original HF model
 base_model = "meta-llama/Llama-3.2-1B"
+import os
+if not os.environ.get("HF_TOKEN"):
+    print("Hint: this model may require Hugging Face auth (HF_TOKEN or `huggingface-cli login`).")
 model = transformers.AutoModelForCausalLM.from_pretrained(
     base_model, torch_dtype="auto", device_map="cuda"
 )

362-405: SGLang container: ensure HF cache env, optional token, and cleanup.

 import subprocess
 import threading
 import os
@@
 container_name = "sglang_serve_spec"
@@
-docker_cmd = [
+subprocess.call(["docker", "rm", "-f", container_name])
+docker_cmd = [
     "docker", "run",
     "--rm",
     "--net", "host",
     "--shm-size=32g",
     "--gpus", "all",
     "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
     "-v", "/tmp:/tmp",
     "--ipc=host",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", "HF_HUB_ENABLE_HF_TRANSFER=1",
+    # optionally: "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
     "--name", container_name,
-    "lmsysorg/sglang:latest",
+    "lmsysorg/sglang:latest",
     "bash", "-c", "bash /tmp/sglang_serve.sh",
 ]

Also consider pinning the image tag instead of “latest” for reproducibility.


271-273: Optional: add readiness polling instead of manual log watching.

Example snippet to insert before sending requests:

import time, requests
def wait_ready(url, timeout=120):
    t0=time.time()
    while time.time()-t0<timeout:
        try:
            r=requests.get(url, timeout=2)
            if r.status_code in (200,404): return True
        except Exception: pass
        time.sleep(2)
    raise TimeoutError(f"Service not ready: {url}")

wait_ready("http://localhost:8000")   # TRT‑LLM
# wait_ready("http://localhost:30000") # SGLang

Also applies to: 411-413


113-120: Verify only the EAGLE head is trainable.

Sanity‑check that base model params are frozen post‑convert; prevents unintended fine‑tuning.

# Diagnostic check (optional cell)
trainable = [n for n,p in model.named_parameters() if p.requires_grad]
print(f"Trainable params: {len(trainable)}")
assert all("eagle" in n.lower() for n in trainable), "Unexpected trainable params outside EAGLE head."
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f85794e and 7540916.

📒 Files selected for processing (1)
  • examples/speculative_decoding/example.ipynb (12 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/example.ipynb
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Comment on lines +61 to 62
"# Convert Model for eagle speculative decoding\n",
"mtsp.convert(model, [(\"eagle\", config)])\n",

💡 Verification agent

🧩 Analysis chain

Verify algorithm naming is consistent across ModelOpt, TRT‑LLM, and SGLang.

mtsp.convert uses "eagle" while TRT‑LLM yaml says "Eagle" and SGLang flag uses "EAGLE3". Confirm accepted identifiers/case to avoid runtime mismatches.

Also applies to: 189-191, 339-342


🌐 Web query:

Check:
1) modelopt.torch.speculative mtsp.convert adapter names (is "eagle" correct for EAGLE3?).
2) TRT-LLM extra_llm_api_config 'speculative_config.decoding_type' accepted values and case ("Eagle" vs "EAGLE3").
3) SGLang --speculative-algorithm accepted values for EAGLE3 and required draft model flags.

💡 Result:

Answers — concise:

  1. modelopt.torch.speculative mtsp.convert adapter name
  • Use "eagle" (lowercase) as the mode string for mtsp.convert. [1][2]
  1. TRT-LLM extra_llm_api_config / speculative_config.decoding_type (case/values)
  • TRT-LLM uses lowercase mode names; the speculative/decoding mode values include (examples) "eagle", "medusa", "lookahead", "redrafter", "top_k", "top_p", "top_k_top_p", etc. Use "eagle" (lowercase) for EAGLE-style speculative decoding. (Note: EAGLE‑2 support is limited/not supported in some Triton/TensorRT-LLM deployments per docs.) [3][4]
  1. SGLang --speculative-algorithm accepted values and draft-model flags
  • SGLang accepts "EAGLE3" (capitalization as shown) for the flag --speculative-algorithm (it documents EAGLE3 explicitly). SGLang requires a draft model path (speculative_draft_model_path / --speculative-draft-model-path) when using EAGLE/EAGLE3; other EAGLE-related flags include speculative_num_steps, speculative_eagle_topk, speculative_num_draft_tokens, etc. [5][6]

Sources

  • [1] NVIDIA TensorRT-Model-Optimizer speculative_decoding API (mtsp.convert modes include "eagle"). (nvidia.github.io)
  • [2] TensorRT-Model-Optimizer examples / eagle module docs. (github.com)
  • [3] TensorRT-LLM / trtllm-build and model_config docs (speculative_decoding_mode / decoding_mode values; "eagle" lowercase). (nvidia.github.io)
  • [4] TensorRT-LLM speculative-decoding tutorial / notes (EAGLE details / EAGLE-2 support notes). (docs.nvidia.com)
  • [5] SGLang docs: backend/advanced pages show --speculative-algorithm EAGLE3 and required speculative_draft_model_path. (docs.sglang.ai)
  • [6] Hugging Face SGLang EAGLE3 draft-model examples (usage showing --speculative-algorithm EAGLE3 + draft model). (huggingface.co)

If you want, I can paste exact example command lines from the docs for each case.


Align speculative-decoding names: use "eagle" for ModelOpt/TRT‑LLM; keep "EAGLE3" for SGLang and provide a draft model.

  • examples/speculative_decoding/example.ipynb (61–62; also 189–191, 339–342): mtsp.convert already uses "eagle" — keep as-is.
  • TRT‑LLM YAML/config: change any "Eagle" → "eagle" (search for speculative_config.decoding_type / decoding_type / speculative_decoding_mode) to match TRT‑LLM/ModelOpt accepted lowercase values.
  • SGLang: continue using "--speculative-algorithm EAGLE3" and ensure a --speculative-draft-model-path / speculative_draft_model_path is supplied.

Add a one-line mapping in the example/README explaining that SGLang uses "EAGLE3" while ModelOpt/TRT‑LLM use "eagle" to avoid confusion.

🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 61-62 (and note
similar usages at 189-191 and 339-342), the notebook already uses
mtsp.convert(model, [("eagle", config)]) which is correct for ModelOpt/TRT‑LLM;
ensure any TRT‑LLM YAML/config keys that reference the speculative decoding type
use lowercase "eagle" (search for speculative_config.decoding_type,
decoding_type, speculative_decoding_mode and change "Eagle" → "eagle"); for
SGLang keep the CLI flag value as "EAGLE3" and ensure the SGLang invocation
provides a draft model via --speculative-draft-model-path (or
speculative_draft_model_path in configs); finally add a one-line note in the
example README mapping that SGLang uses "EAGLE3" while ModelOpt/TRT‑LLM use
"eagle" to avoid confusion.
