update eagle example notebook #314
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Walkthrough

The notebook examples/speculative_decoding/example.ipynb was rewritten into an end-to-end speculative-decoding pipeline: load meta-llama/Llama-3.2-1B, adapt it with an EAGLE3 draft head, train the draft with the HF Trainer, export a unified HF checkpoint, and add container-based deployment scripts for TRT-LLM and SGLang. Some TrainingArguments fields were removed.
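For orientation, here is a minimal sketch of that pipeline, assuming the Model Optimizer speculative-decoding API (mtsp) and the eagle_utils helpers quoted elsewhere in this review; the import paths, the LazySupervisedDataset signature, and the hyperparameters are assumptions rather than the notebook's exact code.

```python
# Hedged sketch of the notebook's train-then-export flow; the API calls follow the
# snippets quoted in this review, but import paths and signatures are assumed.
import copy
import json

import transformers
import modelopt.torch.speculative as mtsp                         # assumed import path
from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG  # assumed import path
from modelopt.torch.export import export_hf_checkpoint            # assumed import path
from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset

base_model = "meta-llama/Llama-3.2-1B"
model = transformers.AutoModelForCausalLM.from_pretrained(base_model, torch_dtype="auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

# Attach the EAGLE3 draft head to the base model.
config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
mtsp.convert(model, [("eagle", config)])

# Train the draft head on Daring-Anteater conversations with the HF Trainer.
with open("/tmp/Daring-Anteater/train.jsonl") as f:
    data_json = [json.loads(line) for line in f]
train_dataset = LazySupervisedDataset(data_json, tokenizer)  # signature assumed

trainer = transformers.Trainer(
    model=model,
    args=transformers.TrainingArguments(
        output_dir="/tmp/eagle_bf16", num_train_epochs=4, per_device_train_batch_size=1
    ),
    train_dataset=train_dataset,
    data_collator=DataCollatorWithPadding(),
)
trainer.train()

# Export a unified HF checkpoint that the serving containers can consume.
model.eval()
export_hf_checkpoint(model, export_dir="/tmp/hf_ckpt")  # keyword assumed
tokenizer.save_pretrained("/tmp/hf_ckpt")
```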
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor U as User
participant NB as Notebook
participant HF as HF Model/Tokenizer
participant DS as Daring-Anteater Dataset
participant TR as HF Trainer
participant CK as HF Checkpoint (/tmp/hf_ckpt)
U->>NB: Run notebook
NB->>HF: Load Llama-3.2-1B + tokenizer
NB->>HF: Adapt model for EAGLE3 (draft head, config sync)
NB->>DS: Load and split dataset (train/eval)
NB->>TR: Configure Trainer (4 epochs, no label_smoother)
TR->>HF: Train draft head
TR-->>NB: Save model/tokenizer
NB->>CK: Export unified HF checkpoint
note over NB,CK: Export replaces prior quantization flow
```

```mermaid
sequenceDiagram
autonumber
actor U as User
participant NB as Notebook
participant TK as /tmp/trtllm_serve.sh
participant SK as /tmp/sglang_serve.sh
participant T as TRT-LLM Server (Docker)
participant S as SGLang Server (Docker)
U->>NB: Execute deployment cells
NB->>TK: Generate TRT-LLM script with speculative_config & kv_cache_config
NB->>T: Launch container (background), stream logs
NB->>T: Send test inference request
NB->>SK: Generate SGLang serve script (speculative settings)
NB->>S: Launch SGLang container
note over T,S: vLLM: placeholder (coming soon)
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
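And a hedged sketch of the container-based deployment step from the second diagram, mirroring the docker invocation discussed later in this review but with an explicit port mapping; the model path, script contents, and image tag are illustrative rather than the notebook's exact cell.

```python
# Illustrative only: write a trtllm-serve script and run it inside a TRT-LLM container.
import subprocess

base_model = "meta-llama/Llama-3.2-1B"  # speculative settings live in the extra YAML config
serve_script = f"""trtllm-serve {base_model} \\
    --host 0.0.0.0 \\
    --port 8000 \\
    --extra_llm_api_options /tmp/extra-llm-api-config.yml
"""
with open("/tmp/trtllm_serve.sh", "w") as f:
    f.write(serve_script)

container_name = "trtllm_serve_spec"
proc = subprocess.Popen(
    [
        "docker", "run", "--rm",
        "-p", "127.0.0.1:8000:8000",  # publish only on the host loopback instead of --net host
        "--gpus", "all",
        "-v", "/tmp:/tmp",
        "--name", container_name,
        "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",  # tag from the notebook; pin/verify first
        "bash", "-c", "bash /tmp/trtllm_serve.sh",
    ]
)
print(f"Started {container_name} (PID {proc.pid}); tail logs with: docker logs -f {container_name}")
```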
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
dfe0b26 to de07831 (Compare)
Codecov Report: ✅ All modified and coverable lines are covered by tests. Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #314   +/-   ##
=======================================
  Coverage   73.82%   73.82%
=======================================
  Files         172      172
  Lines       17438    17438
=======================================
+ Hits        12873    12874       +1
+ Misses       4565     4564       -1
```

☔ View full report in Codecov by Sentry.
de07831 to 293d659 (Compare)
Actionable comments posted: 4
🧹 Nitpick comments (8)
examples/speculative_decoding/example.ipynb (8)
17-18: Prefer snapshot_download over git clone for HF datasets (avoids git-lfs issues). git clone of HF datasets often fails without git-lfs. Use huggingface_hub.snapshot_download for reliability.

```diff
-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id="nvidia/Daring-Anteater", repo_type="dataset", local_dir="/tmp/Daring-Anteater", local_dir_use_symlinks=False, ignore_patterns=["*.md","*.png",".gitattributes",".git/*"])
```
65-72: Chat template fallback may not match the Llama-3.x format. Your fallback uses <|im_start|>, which differs from Llama-3.x templates. Keep it only if the template is truly None; otherwise you risk mismatched formatting. Consider loading the template from the HF repo/config instead of hardcoding it.
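A minimal sketch of that guard, assuming the notebook's tokenizer is loaded from the base checkpoint; whether to fail hard or install a fallback is a judgment call.

```python
import transformers

# Assumption: same base checkpoint as the notebook.
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Only act when the checkpoint genuinely ships no template; never overwrite an existing one.
if tokenizer.chat_template is None:
    raise ValueError(
        "Tokenizer has no chat_template; provide one matching the Llama-3.x format "
        "instead of a generic <|im_start|> fallback."
    )
```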
101-112: Make bf16 selection hardware-aware; add a seed. Hard-coding bf16=True fails on GPUs without BF16. Also set a seed for reproducibility.

```diff
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field
+import torch
+transformers.set_seed(42)
@@
 class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+    dataloader_drop_last: bool = field(default=True)
+    bf16: bool = field(default=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()))
```
129-130: Save the tokenizer with the exported checkpoint (used by the servers). You save the tokenizer to the training dir but not to the exported HF ckpt used for serving.

```diff
-tokenizer.save_pretrained(training_args.output_dir)
+tokenizer.save_pretrained(training_args.output_dir)
+tokenizer.save_pretrained("/tmp/hf_ckpt")  # keep alongside exported model for serving
```
172-180: Derive server max lengths from the model config; bind to localhost. Avoid hard-coding 8192 and exposing on 0.0.0.0. Use model.config.max_position_embeddings and bind to 127.0.0.1 for local demos.

```diff
-context_len = 8192
+context_len = int(getattr(model.config, "max_position_embeddings", 8192))
-trtllm_serve_script = f"""trtllm-serve {base_model} \
-    --host 0.0.0.0 \
+trtllm_serve_script = f"""trtllm-serve {base_model} \
+    --host 127.0.0.1 \
     --port 8000 \
@@
-    --max_num_tokens 8192 \
-    --max_seq_len 8192 \
+    --max_num_tokens {context_len} \
+    --max_seq_len {context_len} \
     --extra_llm_api_options /tmp/extra-llm-api-config.yml
 """
```
283-305: Harden the client request (timeouts + errors). Add a short timeout and raise for non-200 responses to avoid hanging cells.

```diff
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://127.0.0.1:8000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=30,
+)
+response.raise_for_status()
+output = response.json()
```
337-350: Bind the SGLang server to localhost; consider pinning dtype to bf16 if the draft was trained in bf16. Match serving dtype to training where possible, and avoid exposing 0.0.0.0 for local demos.

```diff
-sglang_serve_script = f"""python3 -m sglang.launch_server \
-    --model {base_model} \
-    --host 0.0.0.0 \
+sglang_serve_script = f"""python3 -m sglang.launch_server \
+    --model {base_model} \
+    --host 127.0.0.1 \
@@
-    --dtype float16
+    --dtype bfloat16
 """
```

If BF16 isn't supported, fall back to float16.
438-460: Add a timeout and error checks to the SGLang request.

```diff
-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://127.0.0.1:30000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=30,
+)
+response.raise_for_status()
+output = response.json()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/example.ipynb
(12 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: linux
- GitHub Check: build-docs
🔇 Additional comments (6)
examples/speculative_decoding/example.ipynb (6)
148-152: LGTM on export path. Exporting to a unified HF checkpoint is correct and matches the serving steps.
159-164: LGTM on TRT-LLM deployment narrative. Clear, usable steps with config dumping before container launch.
327-349: LGTM on SGLang configuration options. Flags align with EAGLE-style speculative decoding; minor host/dtype notes addressed above.
189-193: Resolved: use "Eagle" (capital E). Confirmed: trtllm-serve expects decoding_type = "Eagle"; the example already uses this value, so no change is required.
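For reference, a hedged sketch of how the dumped options file could encode that value; the key names mirror the speculative_config/kv_cache_config fields mentioned in this review and are assumed to match the TRT-LLM LLM-API, so verify them against the notebook's actual cell.

```python
# Illustrative only: write the extra options consumed via --extra_llm_api_options.
import yaml

extra_llm_api_config = {
    "speculative_config": {
        "decoding_type": "Eagle",            # capital E, as confirmed above
        "max_draft_len": 3,                  # example value
        "speculative_model_dir": "/tmp/hf_ckpt",
    },
    "kv_cache_config": {
        "enable_block_reuse": False,
    },
}

with open("/tmp/extra-llm-api-config.yml", "w") as f:
    yaml.safe_dump(extra_llm_api_config, f)
```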
380-401: Replace host networking with port mapping; pin the Docker image to lmsysorg/sglang:v0.4.9.post4. Host networking is unsafe; v0.4.9.post4 is the stable tag that includes EAGLE-3.

File: examples/speculative_decoding/example.ipynb (lines 380-401)

```diff
-    "--net",
-    "host",
+    "-p", "30000:30000",
@@
-    "lmsysorg/sglang:latest",
+    "lmsysorg/sglang:v0.4.9.post4",
```
92: Import valid; eagle_utils defines the symbols used. examples/speculative_decoding/eagle_utils.py defines LazySupervisedDataset (line 139) and DataCollatorWithPadding (line 207), so the notebook import is valid.

Likely an incorrect or invalid review comment.
LGTM;
Actionable comments posted: 0
♻️ Duplicate comments (4)
examples/speculative_decoding/example.ipynb (4)
17-18: Don't hardcode the dataset path or rely on git clone; load via datasets and add a guard. git-lfs may be missing and /tmp paths vary. Use the datasets API and/or parameterize the path with an existence check.

Option A (datasets API, preferred):

```diff
-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from datasets import load_dataset
+ds = load_dataset("nvidia/Daring-Anteater", split="train")
+data_json = [x for x in ds]
```

Then replace the file I/O:

```diff
-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
+# data_json already loaded above; keep downstream code as-is
```

Option B (parameterize + guard):

```diff
+from pathlib import Path
+DATASET_PATH = Path(os.getenv("EAGLE_DATASET_PATH", "/tmp/Daring-Anteater/train.jsonl"))
+if not DATASET_PATH.exists():
+    raise FileNotFoundError(
+        f"Dataset not found at {DATASET_PATH}. "
+        "Set EAGLE_DATASET_PATH or install `datasets` and use load_dataset('nvidia/Daring-Anteater')."
+    )
-with open("/tmp/Daring-Anteater/train.jsonl") as f:
+with open(DATASET_PATH, encoding="utf-8") as f:
     data_json = [json.loads(line) for line in f]
```

Also applies to: 95-99
44-46: Remove device_map and the private _move_model_to_device; let Trainer handle placement. Specifying device_map with Trainer and calling a private Trainer method is brittle and can misplace params. Let Trainer place the model (or call model.to(...) once before Trainer).

Apply:

```diff
-model = transformers.AutoModelForCausalLM.from_pretrained(
-    base_model, torch_dtype="auto", device_map="cuda"
-)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+    base_model, torch_dtype="auto"
+)
```

And:

```diff
-trainer._move_model_to_device(model, trainer.args.device)
+# Trainer handles device placement.
```

Also applies to: 121-121
219-246: Avoid --net host; mount the HF cache for gated models; pin/verify the TRT-LLM image. Use explicit port mapping and mount ~/.cache/huggingface so downloads work inside the container. The rc image may be outdated; prefer a stable tag.

```diff
-import subprocess
-import threading
+import subprocess
+import threading
+import os
@@
-container_name = "trtllm_serve_spec"
+container_name = "trtllm_serve_spec"
+home_dir = os.path.expanduser("~")
+hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")
+os.makedirs(hf_cache_dir, exist_ok=True)
@@
-    "--net",
-    "host",
+    "-p", "8000:8000",
@@
     "-v",
     "/tmp:/tmp",
+    "-v",
+    f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
@@
-    "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
+    "nvcr.io/nvidia/tensorrt-llm/release:<stable-tag>",
```

To confirm the latest stable TRT-LLM container and trtllm-serve availability:

What is the latest stable version of the NVIDIA TensorRT-LLM release container, and does it include `trtllm-serve` with speculative decoding/EAGLE support?
384-401: Drop --net host and pin the SGLang image; map the port explicitly. Host networking is unnecessary here and reduces safety; also avoid the floating latest tag.

```diff
-    "--net",
-    "host",
+    "-p", "30000:30000",
@@
-    "lmsysorg/sglang:latest",
+    "lmsysorg/sglang:<pinned-version>",
```
🧹 Nitpick comments (3)
examples/speculative_decoding/example.ipynb (3)
101-112: Gate bf16 by capability; fall back to fp16 for wider compatibility. Hard-enabling bf16 can break on many GPUs. Auto-detect and set fp16 when bf16 is unsupported.

```diff
-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    dataloader_drop_last: bool = field(default=True)
+    bf16: bool = field(default=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()))

 training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
-    num_train_epochs=4,
+    num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    fp16=not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
 )
```
248-267: Optional: poll readiness instead of a manual wait. Poll the HTTP endpoint before sending requests to reduce flakiness.

```diff
 print(
     f"Starting trtllm-serve in Docker (PID: {proc.pid}, container name: {container_name}) in the background:"
 )
+import time, requests as _r
+for _ in range(120):
+    try:
+        if _r.get("http://localhost:8000/v1/models", timeout=2).status_code == 200:
+            break
+    except Exception:
+        time.sleep(1)
```
442-451: Include chat_template in the SGLang request for consistency. You pass chat_template to TRT-LLM but not to SGLang; add it to ensure identical prompts.

```diff
 payload = {
     "model": base_model,
     "messages": [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Tell me about speculative decoding."},
     ],
     "max_tokens": 512,
     "temperature": 0,
+    "chat_template": tokenizer.chat_template,
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/example.ipynb
(12 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (4)
examples/speculative_decoding/example.ipynb (4)
61-63: Ensure only the EAGLE/draft module is trainable; freeze base model params. If convert() doesn't freeze base weights, Trainer will update the full model. Verify and freeze non-EAGLE params.

Minimal guard (after convert):

```diff
 mtsp.convert(model, [("eagle", config)])
+for name, p in model.named_parameters():
+    p.requires_grad = ("eagle" in name)
+trainable = [n for n, p in model.named_parameters() if p.requires_grad]
+print(f"Trainable params: {len(trainable)} -> {trainable[:5]}...")
```

Please confirm trainable names are limited to the EAGLE module.
Also applies to: 113-120
64-72: Tokenizer prep LGTM. Pad token tied to EOS and conditional chat_template initialization look good.
148-153: Export step LGTM. eval() before export and using export_hf_checkpoint to /tmp/hf_ckpt are appropriate.
172-196: Config content looks reasonable for an initial EAGLE trial. Speculative config, kv-cache, and a conservative cuda-graph max batch size are sensible defaults for a demo.
If you plan larger batches later, consider revisiting max_batch_size and cuda_graph_config.
83e06df to 7035565 (Compare)
Signed-off-by: h-guo18 <[email protected]>
7035565 to a2b0681 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (6)
examples/speculative_decoding/example.ipynb (6)
95-99: Do not hardcode the dataset path; add a guard or parameterize it. Make the path configurable and fail clearly if missing.

Apply:

```diff
-import json
-from dataclasses import dataclass, field
-
-from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset
-from transformers import Trainer
-
-with open("/tmp/Daring-Anteater/train.jsonl") as f:
+import json, os
+from pathlib import Path
+from dataclasses import dataclass, field
+from eagle_utils import DataCollatorWithPadding, LazySupervisedDataset
+from transformers import Trainer
+
+data_path = Path(os.getenv("DA_TRAIN_PATH", "/tmp/Daring-Anteater/train.jsonl"))
+if not data_path.exists():
+    raise FileNotFoundError(f"Dataset not found at {data_path}. Set DA_TRAIN_PATH or run the download cell.")
+with open(data_path) as f:
     data_json = [json.loads(line) for line in f]
```

If adopting datasets.load_dataset above, use ds instead of opening a file.
113-121: Don't call private Trainer methods. Remove _move_model_to_device; let Trainer handle placement (after removing device_map).

Apply:

```diff
 trainer = Trainer(
     model=model,
     tokenizer=tokenizer,
     args=training_args,
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     data_collator=DataCollatorWithPadding(),
 )
-trainer._move_model_to_device(model, trainer.args.device)
```
45-46: Remove device_map when using Trainer (let Trainer place the model). This conflicts with Trainer and led you to call a private API later.

Apply:

```diff
-model = transformers.AutoModelForCausalLM.from_pretrained(
-    base_model, torch_dtype="auto", device_map="cuda"
-)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+    base_model, torch_dtype="auto"
+)
```
219-246: Avoid --net=host; mount the HF cache for gated models. Use explicit -p and mount ~/.cache/huggingface so the container sees previously downloaded weights.

Apply:

```diff
-import subprocess
-import threading
-
-# Generate a unique container name so we can stop/remove it later
-container_name = "trtllm_serve_spec"
+import subprocess, threading, os, uuid
+container_name = f"trtllm_serve_spec-{uuid.uuid4().hex[:8]}"
+home_dir = os.path.expanduser("~")
+hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")
+os.makedirs(hf_cache_dir, exist_ok=True)

 docker_cmd = [
     "docker",
     "run",
     "--rm",
-    "--net",
-    "host",
+    "-p", "8000:8000",
     "--shm-size=2g",
     "--ulimit",
     "memlock=-1",
     "--ulimit",
     "stack=67108864",
     "--gpus",
     "all",
     "-v",
     "/tmp:/tmp",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
     "--name",
     container_name,
     "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
     "bash",
     "-c",
     "bash /tmp/trtllm_serve.sh",
 ]
```
369-401: SGLang: avoid --net=host; publish the port explicitly. Mirror the TRT-LLM fix and keep the HF cache mount.

Apply:

```diff
-    "--net",
-    "host",
+    "-p", "30000:30000",
```
61-62: Freeze the base model; train only the draft ("eagle") module as stated. Currently all params appear trainable; freeze non-EAGLE params to match the doc text and reduce memory/compute.

Apply:

```diff
 mtsp.convert(model, [("eagle", config)])
+
+# Freeze base model; train only EAGLE params
+for name, p in model.named_parameters():
+    p.requires_grad = ("eagle" in name)
+
+trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+total = sum(p.numel() for p in model.parameters())
+print(f"Trainable params: {trainable:,} / {total:,}")
```
🧹 Nitpick comments (9)
examples/speculative_decoding/example.ipynb (9)
17-18: Prefer datasets/snapshot_download over git clone to avoid LFS issues. Use HF Datasets or huggingface_hub to reliably fetch the JSONL (works without git-lfs).

Apply:

```diff
-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from datasets import load_dataset
+ds = load_dataset("nvidia/Daring-Anteater", split="train")
+# If you still want a local file: ds.to_json("/tmp/Daring-Anteater/train.jsonl", lines=True)
```
65-72: Don't rely on passing chat_template to servers; pre-render on the client. Some servers ignore client-supplied templates. Render text with tokenizer.apply_chat_template and send a plain prompt.

Example:

```python
prompt = tokenizer.apply_chat_template(payload["messages"], tokenize=False, add_generation_prompt=True)
```

Then call the server API with a prompt/completions endpoint accordingly.
101-112: bf16 may not be supported on all GPUs; add a safe fallback. Guard bf16 or expose the dtype via an arg/env var.

Apply:

```diff
-    bf16: bool = field(default=True)
+    bf16: bool = field(default=False)
+    fp16: bool = field(default=True)
```

Or set it at construction:

```python
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
training_args = TrainingArguments(..., bf16=use_bf16, fp16=not use_bf16)
```
172-204: TRT-LLM: set conservative limits; keep the extra config small. Minor: consider matching max_batch_size to expected interactive usage (e.g., 1-4) and avoid setting both max_num_tokens and max_seq_len redundantly to 8192.
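A small sketch of that adjustment, reusing the length flags from the serve-script diff earlier in this review; the flag spellings and values are illustrative, not verified against trtllm-serve.

```python
# Illustrative: keep the interactive demo's batch size small and set one context length.
context_len = 8192   # or derive from model.config.max_position_embeddings, as noted above
max_batch_size = 4   # sized for interactive use rather than throughput

serve_flags = f"""\
    --max_batch_size {max_batch_size} \\
    --max_num_tokens {context_len} \\
    --max_seq_len {context_len} \\
    --extra_llm_api_options /tmp/extra-llm-api-config.yml
"""
print(serve_flags)
```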
283-305: Add timeout and error handling to requests. Prevent notebook hangs and surface server errors.

Apply:

```diff
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=60,
+)
+response.raise_for_status()
+output = response.json()
```
320-321: Use the same container_name variable for cleanup. The hardcoded name can drift if you randomize names as suggested above.

Apply:

```diff
-!docker rm -f trtllm_serve_spec
+import subprocess
+subprocess.run(["docker", "rm", "-f", container_name], check=False)
```
438-460: Add timeout and error handling to the SGLang client call. Same rationale as the TRT-LLM cell.

Apply:

```diff
-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
+response = requests.post(
+    "http://localhost:30000/v1/chat/completions",
+    headers=headers,
+    data=json.dumps(payload),
+    timeout=60,
+)
+response.raise_for_status()
+output = response.json()
```
475-476: Use the container_name variable for SGLang cleanup. Keep cleanup robust with dynamic names.

Apply:

```diff
-!docker rm -f sglang_serve_spec
+subprocess.run(["docker", "rm", "-f", container_name], check=False)
```
495-510: Kernel name nit: not vLLM-only. The display name "modelopt+vllm" may mislead since the notebook targets TRT-LLM and SGLang; consider a neutral name.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/example.ipynb
(12 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (4)
examples/speculative_decoding/example.ipynb (4)
129-130: LGTM: saving model and tokenizer. Save paths are consistent with output_dir.
148-152: LGTM: export_hf_checkpoint usage. The export step aligns with the subsequent serving configs.
48-59: Avoid mutating the imported default EAGLE3 config. You update a nested dict in-place; this can leak state across runs/imports.

Apply:

```diff
+import copy
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (deep copy to avoid global mutation)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
```

⛔ Skipped due to learnings
Learnt from: yeyu-nvidia PR: NVIDIA/TensorRT-Model-Optimizer#295 File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39 Timestamp: 2025-09-05T19:10:36.393Z Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia PR: NVIDIA/TensorRT-Model-Optimizer#295 File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39 Timestamp: 2025-09-05T19:10:36.393Z Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia PR: NVIDIA/TensorRT-Model-Optimizer#295 File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39 Timestamp: 2025-09-05T19:10:36.393Z Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
242: Pin container tags and verify availability. examples/speculative_decoding/example.ipynb (lines 242, 396): replace nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2 and lmsysorg/sglang:latest with explicit, stable tags and confirm those image tags exist before publishing.

TensorRT-LLM v1.1.0rc2 is published as a release. (github.com)

SGLang docs reference lmsysorg/sglang:latest; avoid :latest and pin a concrete release/tag. (docs.sglang.ai)

I could not run the provided skopeo check in the sandbox (skopeo is not installed); run the original skopeo commands locally (or docker pull the pinned tags) to verify image availability.
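If skopeo is unavailable, a quick alternative sketch using docker itself (docker manifest inspect exits non-zero for a missing tag); the tags below are the ones discussed in this thread.

```python
# Check that the pinned image tags actually exist before publishing the notebook.
import subprocess

images = [
    "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
    "lmsysorg/sglang:v0.4.9.post4",
]
for image in images:
    result = subprocess.run(
        ["docker", "manifest", "inspect", image],
        capture_output=True,
        text=True,
    )
    status = "found" if result.returncode == 0 else "MISSING"
    print(f"{image}: {status}")
```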
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: new example
Overview: This PR updates the existing notebook example of speculative decoding for a blog post. Main changes include:
Usage
See the instructions in example.ipynb.
Testing
Tested by running the notebook with Llama-3.2-1B.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Chores