Skip to content

Conversation

h-guo18
Copy link
Contributor

@h-guo18 h-guo18 commented Sep 5, 2025

What does this PR do?

Type of change: New feature, New example;

Overview:
Add necessary changes and new features for updating eagle example, including:

Feat: add export support for speculative decoding models in hf_unified_export;
Update several arguments names and default values for example simplicity;
Added a few new scripts and renamed some files for the example;
Rewrote the README:

Rearrange content order; Introduce a "simplified workflow" section;
Provided more details in the "Complete workflow" section.
Removed deprecated contents: Nemo link and notebook example;

Usage

See README.md for usage.

# Add a code snippet demonstrating how to use this

Testing

Tested dummy training + ar_validate + export with:

  • Llama3.2-1B
  • Qwen
  • Mixtral
  • Phi-3

Tested deployment on:

  • trtllm-serve
  • SGLang

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • End-to-end train→validate→export workflow with one-line orchestrator, HF checkpoint export supporting speculative-decoding/Eagle format, optional W&B logging, and a new training launch script.
  • Documentation

    • README reorganized into simplified and complete workflows, setup/prereqs, SLURM data-prep guide, dataset instructions, and expanded deployment guidance (TRT-LLM, SGLang, quantized).
  • Changes

    • AR validation and CLI defaults adjusted (more samples/steps, chat enabled by default), draft-vocab sizing moved to explicit CLI, Eagle config updated (rope-scaling, init range).
  • Bug Fixes

    • Centralized draft-vocab mapping, device-alignment fixes, and per-step training accuracy reported in outputs.

@h-guo18 h-guo18 requested a review from yeyu-nvidia September 5, 2025 17:18
Copy link

copy-pr-bot bot commented Sep 5, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Copy link

coderabbitai bot commented Sep 5, 2025

Walkthrough

Reworks the speculative-decoding example into a ModelOpt / Hugging Face–centric end-to-end pipeline: README overhaul, new training/export scripts and SLURM data guide, CLI/calibration/eagle_config changes, HF export plugin and unified export integration, Eagle forward/training logic extensions, and minor telemetry/server tweaks.

Changes

Cohort / File(s) Summary
README & SLURM guide
examples/speculative_decoding/README.md, examples/speculative_decoding/SLURM_prepare_data.md
Major README rewrite to a runnable ModelOpt + HF workflow; added sections on training, export, deployment, optional data synthesis and vocab calibration; new SLURM data-prep doc.
Train orchestration scripts
examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/train_eagle3_and_export.sh
New/updated shell scripts for launching speculative training and a 3‑stage train→validate→export orchestration with CLI parsing, computed defaults, and timestamped output dirs.
Validation & server CLI
examples/speculative_decoding/ar_validate.py, examples/speculative_decoding/server_generate.py
validate_ar default num_samples increased to 80; CLI defaults changed (steps→3, osl adjusted); --chat changed from store_true to --chat bool defaulting to True.
Calibration & eagle config
examples/speculative_decoding/calibrate_draft_vocab.py, examples/speculative_decoding/eagle_config.json
Replaced JSON draft_vocab lookup with required CLI --draft_vocab_size; removed draft_vocab_size from eagle_config.json; added rope_scaling and initializer_range entries.
HF export script
examples/speculative_decoding/export_hf_checkpoint.py
New script enabling ModelOpt HF checkpointing and exporting HF-compatible checkpoint artifacts to an export directory.
Export plugins & unified export
modelopt/torch/export/plugins/__init__.py, modelopt/torch/export/plugins/hf_spec_export.py, modelopt/torch/export/unified_export_hf.py
New hf_spec_export plugin (rename/prune and official-format config generation) and re-export; unified export integrates spec-decoding post-processing, quant-config normalization/persistence, and extends export_hf_checkpoint signature.
Eagle model / training logic
modelopt/torch/speculative/plugins/transformers.py
Removed temporary config overrides; _base_model_forward accepts **kwargs; added _map_logits_to_draft_vocab; support for base_model_outputs path; _eagle_loss now returns (regression, classification, accuracy); forward aggregates per-step accuracies into train_acc; device alignment fixes.
Runtime & telemetry
examples/speculative_decoding/main.py
Guarded optional wandb import/conditional logging; Eagle config propagation of max_position_embeddings from base model.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant TrainScript as train_eagle3_and_export.sh
  participant Launch as launch_train.sh
  participant Trainer as Trainer/ModelOpt
  participant Validator as ar_validate.py
  participant Exporter as export_hf_checkpoint.py
  participant Unified as unified_export_hf.py
  participant Plugin as hf_spec_export.py

  User->>TrainScript: run (base_model, num_gpu, data)
  TrainScript->>Launch: ./launch_train.sh (MODE=eagle3 ...)
  Launch->>Trainer: start training (Eagle architecture)
  Trainer-->>TrainScript: checkpoint → ckpts/<model>-<ts>
  TrainScript->>Validator: python ar_validate.py --model_path <ckpt>
  Validator-->>TrainScript: AR metrics (optionally → wandb)
  TrainScript->>Exporter: python export_hf_checkpoint.py --model_path <ckpt> --export_path <export>
  Exporter->>Unified: export_hf_checkpoint(model, export_dir)
  Unified->>Plugin: rename_and_prune_if_spec_decoding(model, state_dict)
  Unified->>Plugin: set_config_if_spec_decoding(model, config)
  Unified-->>Exporter: saved model/config (+quant cfg)
  Exporter-->>User: "Exported checkpoint to <export>"
Loading
sequenceDiagram
  autonumber
  participant HFModel as HFEagleModel
  participant Base as Base LLM
  participant Eagle as Eagle Module

  Note over HFModel: Forward supports optional base_model_outputs and draft-vocab mapping
  HFModel->>Base: optionally call forward / or accept provided base_model_outputs
  Base-->>HFModel: hidden_states / logits (or omitted)
  HFModel->>HFModel: _map_logits_to_draft_vocab(full_logits)  -- when draft_vocab != vocab
  HFModel->>Eagle: multi-step eagle forward (steps 0..3) with aux states
  Eagle-->>HFModel: eagle_logits per step
  HFModel->>HFModel: _eagle_loss -> (reg_loss, cls_loss, accuracy_k)
  HFModel-->>Caller: ModelOutput(..., train_acc=(acc0..acc3), losses...)
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • ChenhanYu
  • yeyu-nvidia

Pre-merge checks (1 passed, 1 warning, 1 inconclusive)

❌ Failed Checks (1 warning, 1 inconclusive)
Check Name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 47.37% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ❓ Inconclusive The description provides a high-level overview and testing summary but omits detailed rationale and specifics for many of the new scripts, API changes, and plugin integrations, and the contributor checklist items remain as placeholders. Please expand the description with concrete details on each major change, fill in the checklist items with actual values, and consider adding a pull request template to ensure consistency and completeness in future submissions.
✅ Passed Checks (1 passed)
Check Name Status Explanation
Title Check ✅ Passed The title concisely and accurately highlights the primary changes—updating the eagle3 example and adding export support—without extraneous file lists or vague phrasing.

Poem

I trained a draft at dawn's first light,
Mapped tokens clever, pruned them right,
Packaged Eagle's state, then closed the gate,
Exported checkpoints to meet their fate,
A hopping rabbit cheers the pipeline bright! 🐇✨

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
		  - name: "Undocumented Breaking Changes"
			  mode: "warning"
			  instructions: |
				  Flag potential breaking changes that are not documented:
				  1. Identify changes to public APIs/exports, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints (including removed/renamed items and changes to types, required params, return values, defaults, or behavior).
				  2. Ignore purely internal/private changes (e.g., code not exported from package entry points or marked internal).
				  3. Verify documentation exists: a "Breaking Change" section in the PR description and updates to CHANGELOG.md.
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch haoguo/update-eagle-readme

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@h-guo18 h-guo18 force-pushed the haoguo/update-eagle-readme branch from e6334c4 to 47a0a50 Compare September 5, 2025 18:22
@h-guo18 h-guo18 marked this pull request as ready for review September 5, 2025 18:26
@h-guo18 h-guo18 requested review from a team as code owners September 5, 2025 18:26
@h-guo18 h-guo18 requested a review from sugunav14 September 5, 2025 18:26
Copy link

codecov bot commented Sep 5, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.93%. Comparing base (cf6f1d4) to head (854564e).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #293   +/-   ##
=======================================
  Coverage   73.93%   73.93%           
=======================================
  Files         172      172           
  Lines       17408    17408           
=======================================
  Hits        12871    12871           
  Misses       4537     4537           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/server_generate.py (1)

156-158: Fix exception handler referencing undefined prompt in chat mode.

In chat path, prompt is never defined; on errors this raises a new exception and hides the real one.

Apply one of these minimal fixes:

-    except Exception as e:
-        print(e)
-        print(prompt)
-        print("Failed to generate data")
+    except Exception as e:
+        print(e)
+        if "prompt" in locals():
+            print(prompt)
+        else:
+            print("prompt not set (chat mode)")
+        print("Failed to generate data")

Or define prompt = None at function start and print only if not None.

🧹 Nitpick comments (30)
modelopt/torch/export/plugins/hf_spec_export.py (4)

16-16: Fix typos in module docstring.

"Modifiy stated_dict" → "Modify state_dict".

-"""Modifiy stated_dict and config for exporting speculative decoding in official format."""
+"""Modify state_dict and config for exporting speculative decoding in official format."""

23-24: Remove or use unused SPECULATIVE_DECODING_MODES.

Currently unused; either hook it into the guards or drop it to avoid dead code.

-SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]
+# Reserved for future multi-mode handling:
+# SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]

25-44: Rename constant to fix typo and improve clarity.

EALGE_MODELOPT_TO_OFFICIAL → EAGLE_MODELOPT_TO_OFFICIAL; update references.

-EALGE_MODELOPT_TO_OFFICIAL = {
+EAGLE_MODELOPT_TO_OFFICIAL = {
@@
-    _check_state_dict_keys_match(model.eagle_module, EALGE_MODELOPT_TO_OFFICIAL["required"])
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@
-        **EALGE_MODELOPT_TO_OFFICIAL["required"],
-        **EALGE_MODELOPT_TO_OFFICIAL["optional"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],

Also applies to: 65-66, 70-72


68-79: Avoid repeated state_dict() calls and use provided post_state_dict for fallback.

Slight perf/clarity win and safer fallback when lm_head lives only in the passed-in state dict.

-    export_state_dict = {}
-    for ours_key, export_key in {
-        **EALGE_MODELOPT_TO_OFFICIAL["required"],
-        **EALGE_MODELOPT_TO_OFFICIAL["optional"],
-    }.items():
-        if ours_key in model.eagle_module.state_dict():
-            export_state_dict[export_key] = model.eagle_module.state_dict()[ours_key]
+    module_state = model.eagle_module.state_dict()
+    export_state_dict = {}
+    for ours_key, export_key in {
+        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
+    }.items():
+        if ours_key in module_state:
+            export_state_dict[export_key] = module_state[ours_key]
@@
-    if "eagle_lm_head.weight" not in model.eagle_module.state_dict():
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    if "eagle_lm_head.weight" not in module_state:
+        if "lm_head.weight" in post_state_dict:
+            export_state_dict["lm_head.weight"] = post_state_dict["lm_head.weight"]
+        else:
+            # Fall back to model.state_dict() if needed
+            base_state = model.state_dict()
+            if "lm_head.weight" in base_state:
+                export_state_dict["lm_head.weight"] = base_state["lm_head.weight"]
+            else:
+                raise KeyError("lm_head.weight not found in post_state_dict or model.state_dict()")
modelopt/torch/speculative/plugins/transformers.py (4)

721-727: Align reverse_mapping device with logits to avoid cross-device indexing.

Safer if full_logits is on a different device than d2t.

-        reverse_mapping = (
-            torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
-            + self.eagle_module.d2t
-        )
+        reverse_mapping = (
+            torch.arange(len(self.eagle_module.d2t), device=full_logits.device, dtype=torch.long)
+            + self.eagle_module.d2t.to(full_logits.device)
+        )
         return full_logits[:, :, reverse_mapping]

856-858: Nit: fix comment typo.

"diabled" → "disabled".

-            # NOTE: diabled for now.
+            # NOTE: disabled for now.

1052-1058: Compute masked accuracy over valid positions only.

Current mean divides by B*T; use sum over mask / mask.sum() for accurate reporting when masks zero-out prefixes.

-        base_predict_tok = base_model_logits.argmax(dim=-1)
-        eagle_predict_tok = eagle_logits.argmax(dim=-1)
-        accuracy = (
-            (loss_mask[:, :, 0] * (base_predict_tok == eagle_predict_tok)).float().mean().item()
-        )
-        accuracy = round(accuracy, 3)
+        base_predict_tok = base_model_logits.argmax(dim=-1)
+        eagle_predict_tok = eagle_logits.argmax(dim=-1)
+        valid = loss_mask[:, :, 0].bool()
+        correct = (base_predict_tok == eagle_predict_tok) & valid
+        denom = valid.sum().clamp_min(1).float()
+        accuracy = round(correct.sum().float().div(denom).item(), 3)

716-717: Add explicit d2t buffer assertion before vocab remap
Although gating on draft_vocab_size != vocab_size is correct, insert an assertion (e.g. assert hasattr(self, "d2t"), "d2t buffer not initialized") immediately before calling _map_logits_to_draft_vocab to surface misconfigurations more clearly.

modelopt/torch/export/plugins/__init__.py (1)

18-24: Optional: make public surface explicit.

Consider defining all in hf_spec_export to avoid star-import drift.

 with import_plugin("transformers"):
-    from .hf_spec_export import *
+    from .hf_spec_export import *  # relies on hf_spec_export.__all__
examples/speculative_decoding/export_hf_checkpoint.py (2)

25-29: Make CLI args required and self-documenting

Avoid silent defaults. Require both paths and add help for better UX.

-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="")
-    parser.add_argument("--export_path", type=str, default="")
-    return parser.parse_args()
+def parse_args():
+    parser = argparse.ArgumentParser(description="Export a HF checkpoint (with ModelOpt state) for deployment.")
+    parser.add_argument("--model_path", type=str, required=True, help="Path or HF hub id of the trained checkpoint.")
+    parser.add_argument("--export_path", type=str, required=True, help="Destination directory for exported files.")
+    return parser.parse_args()

34-41: Set eval mode prior to export

Ensures deterministic modules (e.g., dropout) are disabled during the dummy forward used in export.

 args = parse_args()
-model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model.eval()
 with torch.inference_mode():
     export_hf_checkpoint(
         model,  # The quantized model.
         export_dir=args.export_path,  # The directory where the exported files will be stored.
     )
modelopt/torch/export/unified_export_hf.py (2)

497-503: Make _quant_applied resilient to missing keys

Minor polish: use dict.get to avoid KeyError if the structure changes.

-def _quant_applied(hf_quant_config: dict) -> bool:
-    """Check if any quantization is applied."""
-    return not (
-        hf_quant_config["quantization"]["quant_algo"] == QUANTIZATION_NONE
-        and not hf_quant_config["quantization"]["quantized_layers"]
-    )
+def _quant_applied(hf_quant_config: dict) -> bool:
+    """Check if any quantization is applied."""
+    q = hf_quant_config.get("quantization", {})
+    return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))

30-33: Optional: avoid hard import if plugins are unused

If plugins are only relevant for Eagle exports, consider importing inside export_hf_checkpoint to reduce import-time side effects and avoid potential cycles.

examples/speculative_decoding/train_eagle3_and_export.sh (4)

53-53: Fix ShellCheck SC2155: avoid command substitution in export

Assign first, then export to prevent masking return values.

-  export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
+  devs="$(seq -s, 0 $((NUM_GPU-1)))"
+  export CUDA_VISIBLE_DEVICES="$devs"

58-66: Create output dirs proactively

Avoid relying on downstream scripts to mkdir. Harmless if they already exist.

 echo "==== [1/3] Training draft model ===="
 OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$OUTPUT_DIR")"
 ./launch_train.sh --model $BASE_MODEL \
             --output_dir $OUTPUT_DIR \
             --data $DATA \
             --num_gpu $NUM_GPU \
             --num_epochs 2 \
             --eagle_config eagle_config.json

70-72: Also mkdir for export path

Prevents failures if parent dir is missing.

 echo "==== [3/3] Exporting checkpoint to deployment format ===="
 EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$EXPORT_PATH")"
 python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH

25-41: Minor: align flag name in comment with implementation

The comment says --base-model but code uses --base_model.

examples/speculative_decoding/README.md (11)

7-9: Fix grammar and proper names in intro.

Use articles and brand casing; current phrasing is awkward.

-This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where Llama3.2-1B from huggingface is trained on Daring-Anteater dataset.
-
-This example focus on training with HF. To train with Megatron-LM, please refer to [this link](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) in Megatron-LM repo.
+This folder contains an end-to-end runnable speculative decoding fine‑tuning pipeline in which Llama‑3.2‑1B (Hugging Face) is trained on the Daring‑Anteater dataset.
+
+This example focuses on training with Hugging Face. To train with Megatron‑LM, see the [Megatron‑LM example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).

65-71: Avoid suggesting hardcoded API keys in examples.

Remove the fake key or clearly denote an env var. This reduces the chance of users pasting secrets.

-pip install vllm
-vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000  --tensor-parallel-size 1
+pip install vllm
+# If your deployment requires an API key, set VLLM_API_KEY in the environment instead of hardcoding.
+VLLM_API_KEY=... vllm serve meta-llama/Llama-3.2-1B-Instruct --port 8000 --tensor-parallel-size 1

90-96: Tighten wording and fix minor grammar.

-For eagle1 and eagle3 we provide an [default model architecture config](.../default_config.py#L18) in modelopt. User can overwrite default settings by providing additional json dict. In this example, we overwrite the `draft_vocab_size` by in `eagle_config.json`:
+For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](.../default_config.py#L18) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:

100-108: Fix typos (“tokenzier”, “hugginface”) and add clarity.

-`main.py` provides a example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
-First, load base model and tokenzier from hugginface:
+`main.py` provides an example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
+First, load the base model and tokenizer from Hugging Face:

118-121: Typo in key name placeholder.

-# overwrite config with custom config
-config["eagle_architecture_config"].update({"<overwrite_kyes>": "<overwrite_values>"})
+# Overwrite config with custom config
+config["eagle_architecture_config"].update({"<overwrite_keys>": "<overwrite_values>"})

131-136: Fix typo and verify API alias.

Spelling: “deocoding” → “decoding”. Also, please confirm mtsp.convert is the correct symbol/alias in this repo.

-Then, we convert model to a speculative deocoding model:
+Then, we convert the model to a speculative decoding model:

140-147: Avoid private Trainer APIs; rely on Trainer for device placement.

trainer._move_model_to_device is a private method and may change; Trainer already handles device placement. Also confirm the correct path for enable_huggingface_checkpointing().

-trainer._move_model_to_device(model, trainer.args.device)
-
 # Enable HF checkpointing so that the saved model will contain the speculative decoding module
-mto.enable_huggingface_checkpointing()
+from modelopt.torch import export as mto_export  # adjust import as needed
+mto_export.enable_huggingface_checkpointing()

152-161: Fix spelling and tighten phrasing in launch instructions.

-... along with a bash script to launch the training with huggingface accelrate in `launch_train.sh`, which can be runned by:
+... along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by:

189-197: Minor YAML lead‑in punctuation and naming consistency.

-To serve the checkpoint with trtllm, we can run trtllm-serve with:
+To serve the checkpoint with TRT‑LLM, run `trtllm-serve`:
 ...
-,
- with `extra-llm-api-config.yml` being
+with `extra-llm-api-config.yml`:

224-233: Brand/style fixes in Support Matrix.

-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Llama 2 | ✅ | ✅ | ✅ |
+| Llama 3, 3.1 | ✅ | ✅ | ✅ |
 ...
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| Qwen 1.5, 2, 2.5 | ✅ | ✅ | ✅ |

236-238: Use correct company casing (“NVIDIA”).

-Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
+Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - NVIDIA TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
examples/speculative_decoding/eagle_config.json (1)

2-8: Align rope_scaling schema with expected HF conventions (consider adding "type": "dynamic").

Given the presence of low_freq_factor, high_freq_factor, and original_max_position_embeddings, many HF configs expect rope_scaling.type="dynamic". If the exporter relies on HF-compatible fields, add the type or confirm your plugin translates this correctly.

Apply this minimal diff if compatible with your pipeline:

   "rope_scaling": {
+        "type": "dynamic",
         "factor": 32.0,
         "low_freq_factor": 1.0,
         "high_freq_factor": 4.0,
         "original_max_position_embeddings": 8192,
         "rope_type": "llama3"
   },
examples/speculative_decoding/server_generate.py (1)

56-56: Make --system_prompt truly optional to avoid odd joining behavior.

With nargs="+" and default "", " ".join(args.system_prompt) can behave unexpectedly. Prefer a list default.

Apply:

-parser.add_argument("--system_prompt", nargs="+", type=str, default="", help="System prompt")
+parser.add_argument("--system_prompt", nargs="*", type=str, default=[], help="System prompt")

system_prompt = " ".join(args.system_prompt) will remain correct (empty string when not provided).

Also applies to: 185-186

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 1cf78b2 and 47a0a50.

📒 Files selected for processing (13)
  • examples/speculative_decoding/README.md (2 hunks)
  • examples/speculative_decoding/ar_validate.py (2 hunks)
  • examples/speculative_decoding/calibrate_draft_vocab.py (2 hunks)
  • examples/speculative_decoding/eagle_config.json (1 hunks)
  • examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/main.py (3 hunks)
  • examples/speculative_decoding/server_generate.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
  • modelopt/torch/export/plugins/__init__.py (1 hunks)
  • modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (4 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
🧰 Additional context used
🧬 Code graph analysis (6)
examples/speculative_decoding/export_hf_checkpoint.py (2)
modelopt/torch/export/unified_export_hf.py (1)
  • export_hf_checkpoint (505-557)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
examples/speculative_decoding/calibrate_draft_vocab.py (1)
modelopt/torch/speculative/utils.py (1)
  • calibrate_frequent_vocab (31-45)
modelopt/torch/export/plugins/hf_spec_export.py (1)
modelopt/torch/speculative/plugins/transformers.py (1)
  • HFEagleModel (333-1138)
modelopt/torch/export/unified_export_hf.py (1)
modelopt/torch/export/plugins/hf_spec_export.py (2)
  • rename_and_prune_if_spec_decoding (55-80)
  • set_config_if_spec_decoding (83-151)
examples/speculative_decoding/main.py (1)
modelopt/torch/export/model_config.py (1)
  • max_position_embeddings (603-605)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/train_eagle3_and_export.sh

[warning] 53-53: Declare and assign separately to avoid masking return values.

(SC2155)

🪛 LanguageTool
examples/speculative_decoding/README.md

[grammar] ~7-~7: There might be a mistake here.
Context: ...ntly improving throughput. This folder contains end-to-end runnable speculative decodin...

(QB_NEW_EN)


[grammar] ~7-~7: There might be a mistake here.
Context: ...Llama3.2-1B from huggingface is trained on Daring-Anteater dataset. This example ...

(QB_NEW_EN)


[grammar] ~9-~9: There might be a mistake here.
Context: ...e/main/examples/post_training/modelopt) in Megatron-LM repo. ## Contents <div al...

(QB_NEW_EN)


[grammar] ~15-~15: There might be a mistake here.
Context: ...tion** | Description | Jump To | | :------------: | :------------: | :---...

(QB_NEW_EN)


[grammar] ~16-~16: There might be a mistake here.
Context: ...---: | :------------: | :------------: | | Pre-Requisites | Required & optional d...

(QB_NEW_EN)


[grammar] ~17-~17: There might be a mistake here.
Context: ...ndencies | [Link] | | Simplified Workflow | Train, evaluate ...

(QB_NEW_EN)


[grammar] ~18-~18: There might be a mistake here.
Context: ...getting-started-simplified-workflow)] | | Complete Workflow | Full example with ...

(QB_NEW_EN)


[grammar] ~19-~19: Ensure spelling is correct
Context: ...rkflow | Full example with configurable traininig pipeline | [Link] ...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~19-~19: There might be a mistake here.
Context: ...pipeline | [Link] | | Support Matrix | Supported models for ...

(QB_NEW_EN)


[grammar] ~20-~20: There might be a mistake here.
Context: ...ing training | [Link] | | Speculation Module Checkpoints | View ...

(QB_NEW_EN)


[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...

(QB_NEW_EN)


[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...

(QB_NEW_EN)


[grammar] ~72-~72: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...

(QB_NEW_EN)


[grammar] ~80-~80: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...

(QB_NEW_EN)


[grammar] ~90-~90: There might be a mistake here.
Context: ... User can overwrite default settings by providing additional json dict. In this example, ...

(QB_NEW_EN)


[grammar] ~100-~100: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load base model and tokenzier fro...

(QB_NEW_EN)


[grammar] ~101-~101: There might be a mistake here.
Context: ... consists of a few simple steps: First, load base model and tokenzier from hugginfac...

(QB_NEW_EN)


[grammar] ~101-~101: Ensure spelling is correct
Context: ...imple steps: First, load base model and tokenzier from hugginface: ```python model = tra...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~109-~109: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...

(QB_NEW_EN)


[grammar] ~109-~109: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...

(QB_NEW_EN)


[grammar] ~131-~131: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative deocoding model:...

(QB_NEW_EN)


[grammar] ~131-~131: Ensure spelling is correct
Context: ...Then, we convert model to a speculative deocoding model: ```python mtsp.convert(model, [...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~137-~137: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...

(QB_NEW_EN)


[grammar] ~167-~167: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...

(QB_NEW_EN)


[grammar] ~226-~226: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...

(QB_NEW_EN)


[grammar] ~227-~227: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...

(QB_NEW_EN)


[grammar] ~228-~228: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...

(QB_NEW_EN)


[grammar] ~229-~229: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...

(QB_NEW_EN)


[grammar] ~230-~230: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...

(QB_NEW_EN)


[grammar] ~231-~231: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...

(QB_NEW_EN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: linux
🔇 Additional comments (18)
examples/speculative_decoding/ar_validate.py (4)

29-29: LGTM! Improved default values for AR validation.

The increase in num_samples from 20 to 80 provides better statistical reliability for AR validation results.


57-57: LGTM! More reasonable default for validation steps.

Changing the default steps from 1 to 3 provides better validation coverage while remaining computationally reasonable.


59-59: LGTM! Adjusted OSL for better performance balance.

Reducing the default osl from 100 to 32 should improve validation speed while maintaining reasonable sequence length coverage.


62-62: LGTM! Consistent with function signature update.

The CLI default now matches the updated function signature default of 80 samples.

examples/speculative_decoding/calibrate_draft_vocab.py (2)

31-35: LGTM! Better separation of concerns.

Moving draft_vocab_size from the config file to a direct CLI parameter provides clearer interface separation and makes the parameter more explicit.


55-55: LGTM! Consistent with the CLI parameter change.

The function call now correctly uses the CLI argument instead of reading from the config file.

examples/speculative_decoding/main.py (3)

50-56: LGTM! Proper optional wandb integration.

The graceful handling of wandb import ensures the script continues to work when wandb is not available, while enabling enhanced logging when it is present.


180-181: LGTM! Essential configuration propagation.

Propagating max_position_embeddings from the base model to the Eagle architecture config ensures deployment compatibility and proper model configuration alignment.


225-226: LGTM! Conditional wandb logging.

The conditional logging ensures AR validation metrics are captured when wandb is available without causing errors when it's not installed.

modelopt/torch/speculative/plugins/transformers.py (5)

688-689: Good change: pass through extra kwargs.

Allows future-proofing without breaking call sites.


791-804: Base-model outputs fast-path looks good.

Nice flexibility to consume teacher-provided tensors and skip recompute.


1079-1080: Good device alignment for base_token.

Avoids device mismatch in downstream concat.


1156-1157: Good: ensure device match in AR validation.

Prevents accidental CPU/GPU mismatch during concatenation.


1021-1029: Verify no positional unpacking of ModelOutput
Ensure no downstream code destructures or uses .to_tuple() assuming the previous tuple shape, as adding train_acc changes its length and order.

modelopt/torch/export/plugins/__init__.py (1)

22-24: Re-export looks fine under transformers plugin guard.

Keeps import-time side effects contained to environments with transformers installed.

examples/speculative_decoding/README.md (1)

41-55: Nice addition: one‑liner workflow with training + export.

Clear and actionable; the defaults link helps users understand what’s applied.

examples/speculative_decoding/eagle_config.json (1)

9-9: Verify initializer_range matches the base model’s setting.

A mismatch can affect initialization and training stability; ensure 0.02 is consistent with the chosen base model config.

examples/speculative_decoding/server_generate.py (1)

49-49: Use argparse.BooleanOptionalAction for --chat instead of type=bool
Replace in examples/speculative_decoding/server_generate.py (around line 49):

-parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
+parser.add_argument(
+    "--chat",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="Use chat mode",
+)

Confirm your minimum Python version is ≥ 3.9 so BooleanOptionalAction exists, and update any docs or scripts to reference --no-chat where needed.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
examples/speculative_decoding/launch_train.sh (4)

95-97: Sanitize OUTPUT_DIR model basename to avoid filesystem issues.

Some model names contain characters problematic for paths. Sanitize the basename before composing OUTPUT_DIR.

 MODEL_BASENAME=$(basename "$MODEL")
-OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
+MODEL_BASENAME_SAFE="${MODEL_BASENAME//[^a-zA-Z0-9._-]/_}"
+OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME_SAFE}-$(date +%Y%m%d_%H%M)"}

77-83: Don’t hard-fail on unknown flags; forward them to the training script.

Exiting on first unknown arg prevents users from passing through Hugging Face/Accelerate options. Collect and forward instead.

-    *)
-      >&2 printf "Error: Invalid argument ${1#*=}\n"
-      exit 1
-      ;;
+    *)
+      EXTRA_ARGS+=("$1")
+      ;;

Add initialization near the top (before the while-loop):

 set -eo pipefail
+
+EXTRA_ARGS=()

130-153: Plumb through extra args to CMD.

Forward any collected unknown args to main.py.

     --data_path $DATA \
-    $SPECULATIVE_ARGS
+    $SPECULATIVE_ARGS \
+    ${EXTRA_ARGS[@]:-}

141-142: Make --do_eval overridable while defaulting to False.

Hardcoding --do_eval False blocks standalone use; keep False by default but allow override via a parsed --do_eval.

-    --do_eval False \
+    --do_eval ${DO_EVAL:-False} \

Add a parser case (outside this hunk) to honor --do_eval:

# In the case/esac parser:
--do_eval*)
  if [[ "$1" != *=* ]]; then shift; fi
  DO_EVAL="${1#*=}"
  ;;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 47a0a50 and 52dbdec.

📒 Files selected for processing (2)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/speculative/plugins/transformers.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
examples/speculative_decoding/launch_train.sh (2)

151-152: Validate required data path early.

If DATA is empty, --data_path will be passed without a value and the run will fail late. Fail fast with a clear message or provide a sensible default.

# Add before building CMD:
if [[ -z "${DATA:-}" ]]; then
  >&2 echo "Error: --data <path> is required (set via --data=...)." 
  exit 1
fi

128-129: Good call disabling tokenizers parallelism.

This avoids noisy warnings and thread oversubscription in tokenization-heavy phases.

This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard number of samples).
In this case, the first 40 shards of data will be processed.
To process the next 40 shards
This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to explain how this mapping works. i.e., tokens = draft_tokens + d2t[draft_tokens]

@h-guo18 h-guo18 force-pushed the haoguo/update-eagle-readme branch from cf17899 to 56fafda Compare September 6, 2025 01:05
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/speculative_decoding/README.md (2)

137-149: Define mtsp/mto imports before use.
Make the snippet self-contained.

-```python
-mtsp.convert(model, [("eagle", config)])
-```
-...
-```python
-# Enable HF checkpointing so that the saved model will contain the speculative decoding module
-mto.enable_huggingface_checkpointing()
-```
+```python
+import modelopt.torch.speculative as mtsp
+import modelopt.torch.export.unified_export_hf as mto
+
+mtsp.convert(model, [("eagle", config)])
+...
+# Enable HF checkpointing so that the saved model will contain the speculative decoding module
+mto.enable_huggingface_checkpointing()
+```

143-151: Avoid private Trainer API and undefined variable.

  • _move_model_to_device is private; Trainer handles device placement.
  • checkpoint isn’t defined; omit unless documented.
-trainer._move_model_to_device(model, trainer.args.device)
-
 # Enable HF checkpointing so that the saved model will contain the speculative decoding module
 mto.enable_huggingface_checkpointing()
 
-trainer.train(resume_from_checkpoint=checkpoint)
+trainer.train()
 trainer.save_state()
♻️ Duplicate comments (4)
modelopt/torch/export/unified_export_hf.py (1)

529-545: Guard spec-decoding transforms when ModelOpt state is absent

Calling spec-decoding transforms unconditionally can break vanilla HF exports. Add a safe gate.

-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
+        if hasattr(model, "_modelopt_state") and getattr(model, "_modelopt_state") is not None:
+            post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
@@
-        config_data = set_config_if_spec_decoding(model, config_data)
+        if hasattr(model, "_modelopt_state") and getattr(model, "_modelopt_state") is not None:
+            config_data = set_config_if_spec_decoding(model, config_data)
examples/speculative_decoding/README.md (3)

19-20: Broken anchors from prior commit are now correct.
“Complete Workflow” and “Support Matrix” Jump To targets are fixed.


172-174: Code-fence language fixed.
CLI is now fenced as bash, not python.


181-185: Export section fences and copy now correct.
Language tag and “format” typo are fixed.

🧹 Nitpick comments (16)
examples/speculative_decoding/SLURM_prepare_data.md (1)

9-10: Quote job name; normalize spacing

Avoid parsing issues with ':' in job name; trim double spaces.

-salloc  -N4 -A <account> -p <partition>  -J <account>-synthetic:data-gen -t 120
+salloc -N4 -A <account> -p <partition> -J "<account>-synthetic:data-gen" -t 120
modelopt/torch/export/unified_export_hf.py (1)

517-526: Path handling nit: prefer Path operations over f-strings

Improves readability and OS-compatibility.

-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+        with open(Path(export_dir) / "hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
@@
-        original_config = f"{export_dir}/config.json"
+        original_config = Path(export_dir) / "config.json"
         config_data = {}
@@
-        with open(original_config, "w") as file:
+        with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)

Also applies to: 536-547

modelopt/torch/speculative/plugins/transformers.py (4)

689-717: Consistent draft-vocab remap policy

_base path remaps logits only during training; teacher-provided path remaps unconditionally. Ensure both paths align with how downstream loss expects logits to be in draft space.

-        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training:
+        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
             assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized"
             base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)

If inference truly needs full-vocab logits, gate at call site instead.


720-726: Bounds/correctness checks and caching for mapping

Add sanity checks and cache the reverse index to avoid recomputing per call.

-    def _map_logits_to_draft_vocab(self, full_logits):
-        reverse_mapping = (
-            torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
-            + self.eagle_module.d2t
-        )
-        return full_logits[:, :, reverse_mapping]
+    def _map_logits_to_draft_vocab(self, full_logits):
+        d2t = self.eagle_module.d2t
+        draft = d2t.numel()
+        if not hasattr(self, "_reverse_draft_index") or self._reverse_draft_index.numel() != draft:
+            self._reverse_draft_index = torch.arange(draft, device=d2t.device) + d2t
+        # Assert indices are in range
+        assert self._reverse_draft_index.max().item() < full_logits.size(-1), "draft→full index OOB"
+        return full_logits.index_select(-1, self._reverse_draft_index)

826-831: Aux hidden states: validate presence/shape in teacher path

If use_aux_hidden_state is True and teacher path is used, assert aux_hidden_states is provided and has expected last-dim = len(layer_ids)*hidden_size to fail fast with a clear message.

-                if "base_model_outputs" in kwargs:
-                    aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
+                if "base_model_outputs" in kwargs:
+                    aux_hidden_states = kwargs["base_model_outputs"].get("aux_hidden_states")
+                    assert aux_hidden_states is not None, "aux_hidden_states required for EAGLE-3 teacher path"

790-803: Clarify DynamicCache initialization and document aux_hidden_states requirement
The past_key_values = None branch is already handled by DynamicCache.from_legacy_cache(None), so no extra guard is required; consider adding an inline comment (# will initialize a fresh DynamicCache below) for clarity. Update the EAGLE-3 docs to specify that when use_aux_hidden_state is enabled, base_model_outputs must include an aux_hidden_states tensor shaped to match eagle_aux_hidden_state_layer_ids.

examples/speculative_decoding/README.md (10)

28-28: Use consistent capitalization: “ModelOpt”.
Brand is capitalized elsewhere; align here.

-Install Modelopt with `hf` dependencies and other requirements for this example:
+Install ModelOpt with `hf` dependencies and other requirements for this example:

47-47: Capitalize “ModelOpt”.

-This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt.
+This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in ModelOpt.

83-90: Explain d2t mapping with a concrete relation.
Add a short sentence showing how the mapping composes tokens.

-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later.
+This produces a `d2t.pt` file in `save_dir`, a mapping from draft vocab to full vocab used by the draft model. Conceptually: `tokens_full = tokens_draft + d2t[tokens_draft]`.

101-101: Capitalize “ModelOpt”.

-### Training Draft Model with Modelopt
+### Training Draft Model with ModelOpt

106-110: Import missing dependency in snippet.
Readers will copy/paste this; include the import.

-```python
-model = transformers.AutoModelForCausalLM.from_pretrained(
+```python
+import transformers
+model = transformers.AutoModelForCausalLM.from_pretrained(

166-170: Capitalize “ModelOpt”.

-The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
+The saved ModelOpt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
-After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by:
+After training the draft model, evaluate the saved ModelOpt checkpoint on MT‑Bench with:

227-235: Polish model names and add “last updated”.
Use canonical casing and note when the matrix was last validated.

-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Llama 2 | ✅ | ✅ | ✅ |
+| Llama 3, 3.1 | ✅ | ✅ | ✅ |
 | Mistral | ✅ | ✅ | ✅ |
 | Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| Qwen 1.5, 2, 2.5 | ✅ | ✅ | ✅ |

Optionally add a line above the table: “Last validated: 2025‑09‑05.”


93-99: Clarify EAGLE-1 vs EAGLE-3 default configs and update link

  • In the README example, call out which settings enable EAGLE-3 (use_aux_hidden_state, eagle_aux_hidden_state_layer_ids, use_mtp_layernorm, eagle_disable_moe, eagle_hidden_state_distillation, etc.) and update the URL to point directly at the EagleConfig default JSON in the ModelOpt API docs.

65-71: vLLM flags and quantization option are up-to-date

  • --api-key (or VLLM_API_KEY) is the current auth flag.
  • --quantization=modelopt has been supported since v0.6.5.

Consider showing environment-based auth (VLLM_API_KEY/BASE_URL) in examples instead of embedding tokens.


194-215: Update README example to match current trtllm-serve flags and YAML schema

  • Confirm the CLI flag --extra_llm_api_options is still supported in your installed version (it appears in the latest docs).
  • Under speculative_config, use the up-to-date keys: speculative_decoding_mode (alias decoding_type), max_draft_len, speculative_model_dir, etc.
  • Under kv_cache_config, include valid fields such as enable_block_reuse, enable_partial_reuse, copy_on_partial_reuse, free_gpu_memory_fraction, and any other keys required by your target release.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 52dbdec and 56fafda.

📒 Files selected for processing (7)
  • examples/speculative_decoding/README.md (2 hunks)
  • examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
  • examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
  • modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (4 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • examples/speculative_decoding/export_hf_checkpoint.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • examples/speculative_decoding/train_eagle3_and_export.sh
🧰 Additional context used
🪛 LanguageTool
examples/speculative_decoding/README.md

[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...

(QB_NEW_EN)


[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...

(QB_NEW_EN)


[grammar] ~71-~71: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...

(QB_NEW_EN)


[grammar] ~83-~83: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...

(QB_NEW_EN)


[grammar] ~103-~103: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load the base model and tokenizer...

(QB_NEW_EN)


[grammar] ~112-~112: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...

(QB_NEW_EN)


[grammar] ~112-~112: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...

(QB_NEW_EN)


[grammar] ~134-~134: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative decoding model: ...

(QB_NEW_EN)


[grammar] ~140-~140: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...

(QB_NEW_EN)


[grammar] ~170-~170: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...

(QB_NEW_EN)


[grammar] ~229-~229: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...

(QB_NEW_EN)


[grammar] ~230-~230: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...

(QB_NEW_EN)


[grammar] ~231-~231: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...

(QB_NEW_EN)


[grammar] ~232-~232: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...

(QB_NEW_EN)


[grammar] ~233-~233: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...

(QB_NEW_EN)


[grammar] ~234-~234: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...

(QB_NEW_EN)

examples/speculative_decoding/SLURM_prepare_data.md

[grammar] ~3-~3: There might be a mistake here.
Context: ...eneration we provide some SLURM support. Assuming a $SLURM_JOB_ID is present an...

(QB_NEW_EN)


[grammar] ~24-~24: There might be a mistake here.
Context: ...generate.pyanddistributed_generate`. This will launch a vllm server (sglang i...

(QB_NEW_EN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/export/unified_export_hf.py (1)

30-33: LGTM: clean plugin import

Importing spec-decoding helpers here aligns the export flow with plugins.

modelopt/torch/speculative/plugins/transformers.py (2)

1082-1083: LGTM: device alignment fix

Moving base_token to input_ids.device avoids cross-device concat errors.


1160-1160: LGTM: device-safe concatenation in AR validation

Ensures both tensors share device during growth loop.

examples/speculative_decoding/README.md (3)

5-5: Intro definition reads well; keep.
Clear and technically accurate overview of speculative decoding.


239-241: Checkpoint links look good.
Clear, actionable pointers to deployable artifacts.


79-79: SLURM guide file present
The referenced SLURM_prepare_data.md exists in examples/speculative_decoding/—no link changes needed.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/export/unified_export_hf.py (1)

520-548: Guard quantization config writes
In modelopt/torch/export/unified_export_hf.py (lines 520–548), wrap the saving of hf_quant_config.json and the insertion of "quantization_config" into config.json behind a check for applied quantization (e.g. if hf_quant_config:), so vanilla HF exports don’t emit empty quant artifacts.

modelopt/torch/speculative/plugins/transformers.py (1)

1154-1160: Avoid hard-coding CUDA device in validation.

Use the input tensor’s device; current_device() breaks on CPU-only or different device placement.

-        input_ids = copy.deepcopy(input_ids).to(torch.cuda.current_device())
+        input_ids = copy.deepcopy(input_ids).to(input_ids.device)
@@
-            input_ids = torch.cat((input_ids, input_id.to(input_ids.device)), dim=-1)
+            input_ids = torch.cat((input_ids, input_id.to(input_ids.device)), dim=-1)
♻️ Duplicate comments (4)
modelopt/torch/export/unified_export_hf.py (2)

544-544: Also guard spec-config transform to avoid AttributeError.

Mirror the pruning guard for config transform.

(Implemented in the diff above.)


529-530: Guard spec-decoding state-dict pruning (works on vanilla HF models).

Call pruning only when ModelOpt state exists.

-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
+        if hasattr(model, "_modelopt_state"):
+            post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
modelopt/torch/export/plugins/hf_spec_export.py (2)

87-95: Guard _modelopt_state access to avoid crashes on vanilla HF.

Directly indexing model._modelopt_state can raise AttributeError. Use the same guard pattern as in rename_and_prune_if_spec_decoding.

-def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
@@
-    if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle":
-        # return as is
-        return config_data
+    opt_modes = getattr(model, "_modelopt_state", None)
+    if (
+        not isinstance(opt_modes, (list, tuple))
+        or len(opt_modes) != 1
+        or opt_modes[0][0] != "eagle"
+    ):
+        return config_data

95-155: Merge with existing config instead of overwriting it.

Dropping unknown keys can remove fields like quantization_config; merge template into incoming config and deep-merge eagle_config.

-    return template_config
+    # Merge: preserve existing fields while overriding with official template values.
+    merged = {**config_data, **template_config}
+    merged["eagle_config"] = {
+        **config_data.get("eagle_config", {}),
+        **template_config["eagle_config"],
+    }
+    return merged
🧹 Nitpick comments (10)
examples/speculative_decoding/SLURM_prepare_data.md (4)

3-5: Tighten wording and fix grammar.

Minor clarity nits.

-For basic parallelization of synthetic data generation we provide some SLURM support.
-Assuming a `$SLURM_JOB_ID` is present and nodes, n1, n2, n3, n4 are selected the following is achievable.
+For basic parallelization of synthetic data generation, we provide SLURM support.
+Assuming `$SLURM_JOB_ID` is present and nodes n1, n2, n3, n4 are allocated, the following is achievable.

12-16: Polish sentence casing.

-Create shards of some given size
+Create shards of a given size

24-26: Capitalize framework names.

-This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10*max_lines_per_shard number of samples).
+This launches a vLLM server (SGLang is also available) on each node. Each node will work through 10 shards of data (10*max_lines_per_shard samples).

29-31: Keep argument list consistent between runs.

The second launch example omits the system prompt argument; include it or add a note that it’s optional to avoid confusion.

-bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
examples/speculative_decoding/README.md (5)

18-18: Capitalize EAGLE and tighten phrasing.

-| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Simplified Workflow | Train, evaluate, and export the EAGLE model with a one‑line command | \[[Link](#getting-started-simplified-workflow)\] |

35-39: Add Git LFS note for dataset clone.

HF datasets via git require LFS; add a one‑liner to reduce user friction.

 ```bash
+git lfs install
 git clone https://huggingface.co/datasets/nvidia/Daring-Anteater

---

`69-69`: **Clarify note reads better as a flag tip.**



```diff
-Note: Add `--quantization=modelopt` flag for quantized models.
+Tip: Add `--quantization=modelopt` when serving quantized models.

101-101: Consistent branding: ModelOpt.

-### (Optional) Configuring Draft Model
+### (Optional) Configuring Draft Model (ModelOpt)

227-235: Model names casing and minor formatting.

-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Llama 2 | ✅ | ✅ | ✅ |
+| Llama 3, 3.1 | ✅ | ✅ | ✅ |
modelopt/torch/speculative/plugins/transformers.py (1)

713-717: Guard d2t usage with a clearer error.

Assertion is fine, but a ValueError is more actionable when draft vocab remapping is expected.

-            assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized"
+            if not hasattr(self.eagle_module, "d2t"):
+                raise ValueError("Draft‑vocab remapping requested but eagle_module.d2t is not initialized")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 52dbdec and 56fafda.

📒 Files selected for processing (7)
  • examples/speculative_decoding/README.md (2 hunks)
  • examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
  • examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
  • modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (4 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/export_hf_checkpoint.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/README.md
🧬 Code graph analysis (3)
modelopt/torch/export/plugins/hf_spec_export.py (1)
modelopt/torch/speculative/plugins/transformers.py (1)
  • HFEagleModel (333-1141)
modelopt/torch/export/unified_export_hf.py (2)
modelopt/torch/export/plugins/hf_spec_export.py (2)
  • rename_and_prune_if_spec_decoding (54-84)
  • set_config_if_spec_decoding (87-155)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/speculative/plugins/transformers.py (1)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
🪛 LanguageTool
examples/speculative_decoding/README.md

[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...

(QB_NEW_EN)


[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...

(QB_NEW_EN)


[grammar] ~71-~71: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...

(QB_NEW_EN)


[grammar] ~83-~83: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...

(QB_NEW_EN)


[grammar] ~103-~103: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load the base model and tokenizer...

(QB_NEW_EN)


[grammar] ~112-~112: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...

(QB_NEW_EN)


[grammar] ~112-~112: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...

(QB_NEW_EN)


[grammar] ~134-~134: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative decoding model: ...

(QB_NEW_EN)


[grammar] ~140-~140: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...

(QB_NEW_EN)


[grammar] ~170-~170: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...

(QB_NEW_EN)


[grammar] ~229-~229: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...

(QB_NEW_EN)


[grammar] ~230-~230: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...

(QB_NEW_EN)


[grammar] ~231-~231: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...

(QB_NEW_EN)


[grammar] ~232-~232: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...

(QB_NEW_EN)


[grammar] ~233-~233: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...

(QB_NEW_EN)


[grammar] ~234-~234: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...

(QB_NEW_EN)

examples/speculative_decoding/SLURM_prepare_data.md

[grammar] ~3-~3: There might be a mistake here.
Context: ...eneration we provide some SLURM support. Assuming a $SLURM_JOB_ID is present an...

(QB_NEW_EN)


[grammar] ~24-~24: There might be a mistake here.
Context: ...generate.pyanddistributed_generate`. This will launch a vllm server (sglang i...

(QB_NEW_EN)

🔇 Additional comments (3)
modelopt/torch/export/plugins/hf_spec_export.py (1)

54-66: Tolerate models without _modelopt_state.

rename_and_prune_if_spec_decoding is already guarded—good.

modelopt/torch/speculative/plugins/transformers.py (2)

680-692: Forward signature change looks good.

Accepts **kwargs for extensibility while retaining existing args.


1022-1030: No downstream consumers of train_acc found. Merging is safe.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/README.md (1)

145-153: Don’t rely on private Trainer APIs; fix undefined checkpoint.

Trainer will move the model to device during training. Also checkpoint is undefined.

-trainer._move_model_to_device(model, trainer.args.device)
@@
-# Enable HF checkpointing so that the saved model will contain the speculative decoding module
+# Enable HF checkpointing so that the saved model will contain the speculative decoding module
 mto.enable_huggingface_checkpointing()
 
-trainer.train(resume_from_checkpoint=checkpoint)
+trainer.train()
 trainer.save_state()
 trainer.save_model("<path to the output directory>")
♻️ Duplicate comments (2)
examples/speculative_decoding/README.md (2)

15-22: TOC anchors now resolve correctly.

Broken anchors/typo from earlier review are fixed.


172-185: Code fences now correctly use bash and grammar is fixed.

Prior issues are resolved.

🧹 Nitpick comments (4)
examples/speculative_decoding/README.md (4)

49-53: Avoid line-number deep links that can drift.

Linking to ...default_config.py#L18 is brittle. Point to the file (or a permalink) instead.

-- Initializes the draft model with [default settings](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18)
+- Initializes the draft model with [default settings](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py)

83-90: Great fix on the flag; tighten wording and clarify mapping.

The extra dashes issue is resolved. Suggest minor wording + add a one-liner showing how d2t is applied.

-We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
+We can optionally use a smaller draft vocabulary to speed up training/inference. For example, Llama‑3.2‑1B uses a 128,256‑token vocabulary. Here we build a 32k draft‑to‑target mapping by selecting the most frequent tokens in the training set:
@@
-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later.
+This produces `d2t.pt` in `save_dir`, a draft‑to‑target mapping the draft model will load later (applied roughly as: accepted_tokens = draft_tokens + d2t[draft_tokens]).

93-99: Link stability.

config.py#L37 may drift; consider linking to the file (or a permalink).


166-167: Unify brand capitalization: ModelOpt.

Use “ModelOpt” consistently.

-The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
+The saved ModelOpt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 56fafda and ed1f5a3.

📒 Files selected for processing (2)
  • examples/speculative_decoding/README.md (2 hunks)
  • modelopt/torch/export/unified_export_hf.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/unified_export_hf.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • examples/speculative_decoding/README.md
🪛 LanguageTool
examples/speculative_decoding/README.md

[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...

(QB_NEW_EN)


[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...

(QB_NEW_EN)


[grammar] ~71-~71: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...

(QB_NEW_EN)


[grammar] ~83-~83: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...

(QB_NEW_EN)


[grammar] ~103-~103: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load the base model and tokenizer...

(QB_NEW_EN)


[grammar] ~112-~112: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...

(QB_NEW_EN)


[grammar] ~112-~112: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...

(QB_NEW_EN)


[grammar] ~134-~134: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative decoding model: ...

(QB_NEW_EN)


[grammar] ~140-~140: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...

(QB_NEW_EN)


[grammar] ~170-~170: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...

(QB_NEW_EN)


[grammar] ~229-~229: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...

(QB_NEW_EN)


[grammar] ~230-~230: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...

(QB_NEW_EN)


[grammar] ~231-~231: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...

(QB_NEW_EN)


[grammar] ~232-~232: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...

(QB_NEW_EN)


[grammar] ~233-~233: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...

(QB_NEW_EN)


[grammar] ~234-~234: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...

(QB_NEW_EN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
examples/speculative_decoding/README.md (6)

5-6: Clear, technically accurate intro.

The γ/α explanation is concise and correct.


227-236: Support matrix may be outdated.

A reviewer noted more models are now supported; update once verified on this HF path.


239-241: Nice addition.

Checkpoint collection link and deployment note are helpful.


71-75: No action needed: script path is correct
server_generate.py is present at examples/speculative_decoding/server_generate.py, matching the README’s invocation when run from that directory.


194-205: Validate TRT-LLM serve flags and adjust wording.

Most CLI flags (--backend pytorch, --max_batch_size, --max_num_tokens, --max_seq_len, --extra_llm_api_options) and YAML keys (enable_attention_dp, disable_overlap_scheduler, cuda_graph_config, speculative_config.*) match current TRT-LLM docs. I couldn’t locate enable_autotuner—please confirm it’s supported in your target release or remove it.

-, with `extra-llm-api-config.yml` being
+With `extra-llm-api-config.yml`:

62-70: vLLM flags confirmed

  • --api-key is optional for local runs; if provided (or via VLLM_API_KEY), the server will enforce the specified key(s) in request headers.
  • --quantization=modelopt is supported as one of vLLM’s quantization backends.

@h-guo18 h-guo18 force-pushed the haoguo/update-eagle-readme branch from ed1f5a3 to 240eac1 Compare September 6, 2025 02:02
@ChenhanYu ChenhanYu self-requested a review September 6, 2025 15:06
@ChenhanYu
Copy link
Collaborator

@yeyu-nvidia
Copy link
Contributor

This needs to be addressed #293 (comment). Otherwise, LGTM.

@h-guo18 h-guo18 force-pushed the haoguo/update-eagle-readme branch 2 times, most recently from 023b27a to faf36b7 Compare September 8, 2025 17:23
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (11)
examples/speculative_decoding/launch_train.sh (1)

87-91: Guard torch detection and prevent division-by-zero; prefer NUM_GPU when available.

python -c 'import torch' can fail under set -e, and torch.cuda.device_count() may return 0, yielding a division-by-zero or DEFAULT_SAVE_STEPS=0. Use ${NUM_GPU} when valid, fall back robustly to 1, and clamp to min 1.

-# Get the default value for save_steps based on the available number of GPUs
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+# Get the default value for save_steps based on the available number of GPUs
+# Prefer NUM_GPU if provided; else try torch; finally fall back to 1.
+if [[ -n "${NUM_GPU:-}" && "${NUM_GPU}" =~ ^[0-9]+$ && ${NUM_GPU} -gt 0 ]]; then
+  GPU_COUNT=${NUM_GPU}
+else
+  GPU_COUNT=$(python - <<'PY'
+try:
+  import torch
+  print(torch.cuda.device_count() or 0)
+except Exception:
+  print(0)
+PY
+  ) || GPU_COUNT=0
+  [[ -z "$GPU_COUNT" || "$GPU_COUNT" -le 0 ]] && GPU_COUNT=1
+fi
+# Calculate save_steps safely (min 1)
+DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+(( DEFAULT_SAVE_STEPS < 1 )) && DEFAULT_SAVE_STEPS=1
modelopt/torch/export/plugins/hf_spec_export.py (6)

18-20: Record transformers version (best-effort) for traceability.

Populate transformers_version to aid downstream tooling.

@@
-import torch
+import torch
 import torch.nn as nn
+try:
+    import transformers as _hf
+    _TRANSFORMERS_VERSION = getattr(_hf, "__version__", None)
+except Exception:
+    _TRANSFORMERS_VERSION = None

76-79: Harden lm_head fallback; raise if neither drafter nor base provides it.

Prevents silent KeyError from model.state_dict()["lm_head.weight"].

@@
-    if "eagle_lm_head.weight" not in eagle_state:
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    if "eagle_lm_head.weight" not in eagle_state:
+        base_state = model.state_dict()
+        if "lm_head.weight" in base_state:
+            export_state_dict["lm_head.weight"] = base_state["lm_head.weight"]
+        else:
+            raise KeyError("Missing 'eagle_lm_head.weight' in drafter and 'lm_head.weight' in base model.")

127-134: Guard eagle_config/config access before getattr.

Avoids AttributeError when either container is missing.

@@
-    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
-        if getattr(model.eagle_config, key, None) is not None:
-            return getattr(model.eagle_config, key)
-        elif getattr(model.config, key, None) is not None:
-            return getattr(model.config, key)
-        else:
-            return None
+    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
+        eagle_cfg = getattr(model, "eagle_config", None)
+        if eagle_cfg is not None and getattr(eagle_cfg, key, None) is not None:
+            return getattr(eagle_cfg, key)
+        base_cfg = getattr(model, "config", None)
+        if base_cfg is not None and getattr(base_cfg, key, None) is not None:
+            return getattr(base_cfg, key)
+        return None

89-108: Preserve original config fields and merge; set transformers_version.

Avoid dropping unknown config keys and persist the detected Transformers version.

@@
-        "transformers_version": None,
+        "transformers_version": _TRANSFORMERS_VERSION,
@@
-    return template_config
+    # Merge with original to preserve unknown keys.
+    merged = {**config_data, **template_config}
+    merged["eagle_config"] = {
+        **config_data.get("eagle_config", {}),
+        **template_config["eagle_config"],
+    }
+    return merged

Also applies to: 149-149


84-88: set_config_if_spec_decoding crashes on vanilla HF models.

Direct len(model._modelopt_state) will raise when the attribute is absent.

@@
-    if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle":
+    opt_modes = getattr(model, "_modelopt_state", None)
+    if (
+        not isinstance(opt_modes, (list, tuple))
+        or len(opt_modes) != 1
+        or opt_modes[0][0] != "eagle"
+    ):
         # return as is
         return config_data

51-65: Guard for missing eagle_module to avoid AttributeError.

Accessing model.eagle_module without a presence check can crash on non-Eagle models.

@@
-    # Check if the state dict keys match
-    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+    # Ensure eagle_module exists
+    if not hasattr(model, "eagle_module"):
+        return post_state_dict
+    # Check if the state dict keys match
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
modelopt/torch/export/unified_export_hf.py (4)

494-498: Fix _quant_applied logic; make it reliable.

Current check can return True for empty configs and is used nowhere. Tighten and use it.

@@
-def _quant_applied(hf_quant_config: dict) -> bool:
-    """Check if any quantization is applied."""
-    q = hf_quant_config.get("quantization", {})
-    return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))
+def _quant_applied(hf_quant_config: dict) -> bool:
+    """Return True iff any quantization is configured."""
+    q = hf_quant_config.get("quantization") or {}
+    algo = q.get("quant_algo")
+    layers = q.get("quantized_layers")
+    return (algo is not None and algo != QUANTIZATION_NONE) or bool(layers)

519-523: Don’t emit empty hf_quant_config.json.

Only write when quantization is actually applied; also use Path for portability.

@@
-        # Save hf_quant_config.json for backward compatibility
-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
-            json.dump(hf_quant_config, file, indent=4)
+        # Save hf_quant_config.json only when quantization is applied
+        if _quant_applied(hf_quant_config):
+            with open(Path(export_dir) / "hf_quant_config.json", "w") as file:
+                json.dump(hf_quant_config, file, indent=4)

526-527: Guard spec-decoding post-state transform when ModelOpt state is absent.

Prevents AttributeError on vanilla HF models.

@@
-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
+        if hasattr(model, "_modelopt_state"):
+            post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)

539-544: Preserve quantization_config and avoid spec-config clobbering; gate on quant applied.

Move assignment after spec-config transform and gate it. Also remove stale key when no-quant.

@@
-        config_data["quantization_config"] = hf_quant_config
-
-        config_data = set_config_if_spec_decoding(model, config_data)
+        if hasattr(model, "_modelopt_state"):
+            config_data = set_config_if_spec_decoding(model, config_data)
+        if _quant_applied(hf_quant_config):
+            config_data["quantization_config"] = hf_quant_config
+        else:
+            config_data.pop("quantization_config", None)
🧹 Nitpick comments (4)
examples/speculative_decoding/launch_train.sh (4)

111-116: Quote config path to handle spaces and special chars.

Unquoted paths can break argument parsing.

-    SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
+    SPECULATIVE_ARGS="--eagle_config \"$EAGLE_CONFIG\""

130-153: Quote key CLI values to be robust to spaces.

Quote MODEL, OUTPUT_DIR, and DATA in the assembled command.

-    --model_name_or_path $MODEL \
+    --model_name_or_path \"$MODEL\" \
@@
-    --output_dir $OUTPUT_DIR \
+    --output_dir \"$OUTPUT_DIR\" \
@@
-    --data_path $DATA \
+    --data_path \"$DATA\" \

103-106: Remove or wire unused variables.

REDRAFTER_TOKENS and REDRAFTER_NUM_LAYERS are set but unused.


77-80: Improve invalid-argument error message.

Print the full offending token; current expansion may mangle it.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument: %s\n" "$1"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d416ba9 and faf36b7.

📒 Files selected for processing (14)
  • examples/speculative_decoding/README.md (2 hunks)
  • examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
  • examples/speculative_decoding/ar_validate.py (2 hunks)
  • examples/speculative_decoding/calibrate_draft_vocab.py (2 hunks)
  • examples/speculative_decoding/eagle_config.json (1 hunks)
  • examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
  • examples/speculative_decoding/launch_train.sh (1 hunks)
  • examples/speculative_decoding/main.py (3 hunks)
  • examples/speculative_decoding/server_generate.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
  • modelopt/torch/export/plugins/__init__.py (1 hunks)
  • modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (4 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/speculative_decoding/SLURM_prepare_data.md
🚧 Files skipped from review as they are similar to previous changes (10)
  • examples/speculative_decoding/ar_validate.py
  • examples/speculative_decoding/calibrate_draft_vocab.py
  • examples/speculative_decoding/eagle_config.json
  • modelopt/torch/export/plugins/init.py
  • examples/speculative_decoding/server_generate.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/README.md
  • examples/speculative_decoding/export_hf_checkpoint.py
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • modelopt/torch/speculative/plugins/transformers.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/export/plugins/hf_spec_export.py
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/export/plugins/hf_spec_export.py
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/export/plugins/hf_spec_export.py
🧬 Code graph analysis (1)
modelopt/torch/export/unified_export_hf.py (1)
modelopt/torch/export/plugins/hf_spec_export.py (2)
  • rename_and_prune_if_spec_decoding (51-80)
  • set_config_if_spec_decoding (83-149)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
examples/speculative_decoding/launch_train.sh (1)

151-152: Validate required --data early

Insert immediately after export TOKENIZERS_PARALLELISM=False in examples/speculative_decoding/launch_train.sh:

if [[ -z "${DATA:-}" ]]; then
  echo "Error: --data PATH is required." >&2
  exit 1
fi
modelopt/torch/export/plugins/hf_spec_export.py (1)

54-61: Nice: tolerant gate on _modelopt_state in rename path.

Good defensive checks; keeps non-Eagle paths unaffected.

modelopt/torch/export/unified_export_hf.py (1)

56-56: Import placement is good; plugin usage clear.

Spec-decoding integration point is well-scoped.

Comment on lines +69 to +76
--fsdp_transformer_layer_cls_to_wrap*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
;;
--num_gpu*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_GPU="${1#*=}"
;;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Parsed FSDP arg is never forwarded to the training command.

--fsdp_transformer_layer_cls_to_wrap is parsed but not passed to main.py, so user input is ignored.

@@
     --num_gpu*)
       if [[ "$1" != *=* ]]; then shift; fi
       NUM_GPU="${1#*=}"
       ;;
@@
 fi
@@
 export TOKENIZERS_PARALLELISM=False
+FSDP_ARGS=""
+if [[ -n "${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-}" ]]; then
+  FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap ${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP}"
+fi
 CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
@@
     --data_path $DATA \
-    $SPECULATIVE_ARGS
+    $SPECULATIVE_ARGS \
+    $FSDP_ARGS
 "

Also applies to: 130-153

Comment on lines +122 to +127
if [[ "$NUM_GPU" == 1 ]]; then
MULTI_GPU=""
else
MULTI_GPU="--multi_gpu"
fi

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

❓ Verification inconclusive

Multi-GPU launch: pass explicit process count to Accelerate; --multi_gpu may be ignored on newer versions.

Safer to specify --num_processes "${NUM_GPU}" and drop the custom flag.

-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
+LAUNCH_OPTS="--mixed_precision bf16"
+if [[ "${NUM_GPU}" -gt 1 ]]; then
+  LAUNCH_OPTS+=" --num_processes ${NUM_GPU}"
+fi
@@
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch $LAUNCH_OPTS main.py \

Also applies to: 130-131


Use explicit --num_processes instead of --multi_gpu for multi-GPU runs

Accelerate launch supports --num_processes=<N> alone to spawn N GPUs (and implicitly use MULTI_GPU) without requiring --multi_gpu (huggingface.co, modeldatabase.com)

-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
+LAUNCH_OPTS="--mixed_precision bf16"
+if [[ "${NUM_GPU}" -gt 1 ]]; then
+  LAUNCH_OPTS+=" --num_processes ${NUM_GPU}"
+fi
@@
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch $LAUNCH_OPTS main.py \

Also update the same pattern at lines 130–131.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/speculative_decoding/launch_train.sh around lines 122–127 (and also
update the same pattern at lines 130–131), replace the current multi-GPU flag
logic that sets MULTI_GPU="--multi_gpu" with an explicit process-count flag:
when NUM_GPU==1 keep MULTI_GPU empty, otherwise set
MULTI_GPU="--num_processes=$NUM_GPU"; update any subsequent invocations that
previously relied on --multi_gpu to use this MULTI_GPU variable so Accelerate is
launched with --num_processes=<N> instead of --multi_gpu.

@h-guo18
Copy link
Contributor Author

h-guo18 commented Sep 8, 2025

/ok to test faf36b7

@h-guo18
Copy link
Contributor Author

h-guo18 commented Sep 8, 2025

/ok to test 6b2411e

import torch.nn as nn

EAGLE_MODELOPT_TO_OFFICIAL = {
"required": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this from megatron to HF?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the param name mapping from modelopt::HFEagleModel to eagle3 official checkpoint. I bielieve both megatron and HF export to the same format.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latter is also consistent with our previous released eagle checkpoints.

@h-guo18 h-guo18 force-pushed the haoguo/update-eagle-readme branch from 6b2411e to 854564e Compare September 8, 2025 23:30
@h-guo18
Copy link
Contributor Author

h-guo18 commented Sep 8, 2025

/ok to test 854564e

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/speculative_decoding/server_generate.py (2)

138-154: Use the Completions API in non-chat mode (current call is inconsistent and likely broken).

You're passing prompt= to client.chat.completions.create and later reading .text. Switch to client.completions.create and use the returned .text.

-            response = client.chat.completions.create(
+            response = client.completions.create(
                 model=model_name,
                 prompt=prompt,
                 max_tokens=args.max_tokens,
                 temperature=args.temperature,
                 ignore_eos=False,
                 skip_special_tokens=False,
                 spaces_between_special_tokens=False,
             )
-            response = response.choices[0].text.strip()
+            response_text = response.choices[0].text.strip()
             with open(args.output_path, "a") as f:
                 # write in share gpt format
                 if args.log_empty_conversations:
-                    to_write = {"conversation_id": idx, "text": prompt + response}
+                    to_write = {"conversation_id": idx, "text": prompt + response_text}
                 else:
-                    to_write = {"text": prompt + response}
+                    to_write = {"text": prompt + response_text}
                 f.write(json.dumps(to_write) + "\n")

155-159: Avoid UnboundLocalError when printing prompt in exceptions.

prompt is undefined in chat mode or when the error occurs before assignment.

-    except Exception as e:
-        print(e)
-        print(prompt)
-        print("Failed to generate data")
+    except Exception as e:
+        print(e)
+        if "prompt" in locals():
+            print(prompt)
+        print("Failed to generate data")
♻️ Duplicate comments (6)
examples/speculative_decoding/server_generate.py (1)

49-49: Fix boolean CLI parsing for --chat.

type=bool treats any non-empty string (even "False") as True. Use BooleanOptionalAction so --chat/--no-chat work correctly.

-parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
+parser.add_argument(
+    "--chat",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="Use chat mode (use --no-chat to disable)",
+)
modelopt/torch/export/plugins/hf_spec_export.py (4)

89-125: Preserve original config_data fields and set transformers_version; deep-merge eagle_config.

Current code replaces the entire config, dropping unknown fields. Merge defaults into the original and record the Transformers version.

@@
-    # This is the config keys in official checkpoint.
+    # This is the config keys in official checkpoint.
     template_config = {
@@
-        "transformers_version": None,
+        "transformers_version": getattr(transformers, "__version__", None),
@@
-    for key in template_config:
+    for key in template_config:
         value = template_config[key]
@@
-            template_config[key] = new_value
+            template_config[key] = new_value
@@
-    return template_config
+    # Merge: keep any unknown keys from original config_data
+    merged = {**config_data, **template_config}
+    # Deep-merge nested eagle_config
+    merged["eagle_config"] = {
+        **config_data.get("eagle_config", {}),
+        **template_config["eagle_config"],
+    }
+    return merged

Also applies to: 135-149


76-79: Harden fallback for lm_head.weight export.

Explicitly check base model state and raise a clear error if neither key exists.

-    # TODO: (hg) this is a temp fix. Find cleaner way to do this.
-    if "eagle_lm_head.weight" not in eagle_state:
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    # TODO: (hg) this is a temp fix. Find cleaner way to do this.
+    if "eagle_lm_head.weight" not in eagle_state:
+        base_state = model.state_dict()
+        if "lm_head.weight" in base_state:
+            export_state_dict["lm_head.weight"] = base_state["lm_head.weight"]
+        else:
+            raise KeyError(
+                "Missing 'eagle_lm_head.weight' in draft and 'lm_head.weight' in base model."
+            )

127-134: Guard access to eagle_config/config to avoid AttributeError.

-    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
-        if getattr(model.eagle_config, key, None) is not None:
-            return getattr(model.eagle_config, key)
-        elif getattr(model.config, key, None) is not None:
-            return getattr(model.config, key)
-        else:
-            return None
+    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
+        eagle_cfg = getattr(model, "eagle_config", None)
+        if eagle_cfg is not None and getattr(eagle_cfg, key, None) is not None:
+            return getattr(eagle_cfg, key)
+        base_cfg = getattr(model, "config", None)
+        if base_cfg is not None and getattr(base_cfg, key, None) is not None:
+            return getattr(base_cfg, key)
+        return None

51-66: Guard for Eagle mode and presence of eagle_module before accessing it.

If _modelopt_state indicates Eagle but eagle_module is missing, this raises AttributeError. Early-return safely.

 def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
@@
-    opt_modes = getattr(model, "_modelopt_state", None)
+    opt_modes = getattr(model, "_modelopt_state", None)
     if (
         not isinstance(opt_modes, (list, tuple))
         or len(opt_modes) != 1
         or opt_modes[0][0] != "eagle"
     ):
         # if there's other opts, return as is
         return post_state_dict
+    if not hasattr(model, "eagle_module"):
+        # Not an Eagle-wrapped model; return unchanged
+        return post_state_dict
@@
-    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
examples/speculative_decoding/README.md (1)

28-33: Provide concrete install commands (replace placeholder).

Make this copy-pastable; include PyPI install and editable source install with HF extras.

-Install Modelopt with `hf` dependencies and other requirements for this example:
+Install ModelOpt with `hf` extras and example requirements:
@@
-```bash
-pip install -e ...
-pip install -r requirements.txt
-```
+```bash
+# Option A: Install published package (recommended)
+pip install "nvidia-modelopt[hf]"
+
+# Option B: Install from source
+pip install -e .[hf]
+
+# Example-specific deps
+pip install -r requirements.txt
+```
🧹 Nitpick comments (6)
examples/speculative_decoding/server_generate.py (3)

56-57: Use list default for --system_prompt to avoid odd joins.

Defaulting to "" makes " ".join(args.system_prompt) iterate characters. Use an empty list.

-parser.add_argument("--system_prompt", nargs="+", type=str, default="", help="System prompt")
+parser.add_argument("--system_prompt", nargs="+", type=str, default=[], help="System prompt")

43-44: Lower default thread count to a safer value.

256 workers can easily overwhelm local servers and hit rate limits. Suggest 32–64 by default.

-parser.add_argument(
-    "--num_threads", type=int, default=256, help="Number of threads to use (batch size)"
+parser.add_argument(
+    "--num_threads", type=int, default=64, help="Number of threads to use (batch size)"
 )

60-65: Use context manager for JSON load for consistency and file handle safety.

-if args.data_path.endswith("jsonl"):
-    with open(args.data_path) as f:
-        data = [json.loads(line) for line in f]
-else:
-    data = json.load(open(args.data_path))
+if args.data_path.endswith("jsonl"):
+    with open(args.data_path) as f:
+        data = [json.loads(line) for line in f]
+else:
+    with open(args.data_path) as f:
+        data = json.load(f)
modelopt/torch/export/plugins/hf_spec_export.py (1)

18-20: Record Transformers version for provenance.

Import transformers to populate transformers_version in the exported config.

 import torch
 import torch.nn as nn
+import transformers
examples/speculative_decoding/README.md (2)

47-48: Brand consistency: “ModelOpt” capitalization.

Use “ModelOpt” consistently.

-This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
+This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in ModelOpt. Specifically, it

229-236: Validate or trim the Support Matrix to avoid drift.

Section can get stale quickly; either confirm the list against current HF path support or drop the table and link to dynamic docs.

Would you like me to open a follow-up PR to auto-generate this table from a tested registry?

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6b2411e and 854564e.

📒 Files selected for processing (14)
  • examples/speculative_decoding/README.md (2 hunks)
  • examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
  • examples/speculative_decoding/ar_validate.py (2 hunks)
  • examples/speculative_decoding/calibrate_draft_vocab.py (2 hunks)
  • examples/speculative_decoding/eagle_config.json (1 hunks)
  • examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
  • examples/speculative_decoding/launch_train.sh (1 hunks)
  • examples/speculative_decoding/main.py (3 hunks)
  • examples/speculative_decoding/server_generate.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
  • modelopt/torch/export/plugins/__init__.py (1 hunks)
  • modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (3 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (11)
  • examples/speculative_decoding/eagle_config.json
  • examples/speculative_decoding/export_hf_checkpoint.py
  • examples/speculative_decoding/SLURM_prepare_data.md
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/ar_validate.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/export/plugins/init.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/calibrate_draft_vocab.py
  • modelopt/torch/speculative/plugins/transformers.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/export/plugins/hf_spec_export.py
  • examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/export/plugins/hf_spec_export.py
  • examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.

Applied to files:

  • modelopt/torch/export/plugins/hf_spec_export.py
  • examples/speculative_decoding/README.md
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

@h-guo18
Copy link
Contributor Author

h-guo18 commented Sep 8, 2025

/ok to test 854564e

@h-guo18 h-guo18 merged commit a6fa34c into main Sep 9, 2025
24 of 27 checks passed
@h-guo18 h-guo18 deleted the haoguo/update-eagle-readme branch September 9, 2025 01:02
benchislett pushed a commit that referenced this pull request Sep 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet