
Conversation

h-guo18
Contributor

@h-guo18 h-guo18 commented Sep 26, 2025

What does this PR do?

Type of change: Bug Fix

Overview:

Fixed 2 bugs when training with a quantized base model:

  • Some base model modules are not loaded but are still present in _keep_in_fp32_modules, raising an error here when loading the model.

    • Solution: Instead of passing num_hidden_layers=0 to from_pretrained() to load a partial base model, we load the entire base model to CPU first, then delete base_model.layers during convert() (see the sketch after this list).
  • The base model is considered a "purely quantized model" by the HF Trainer and therefore cannot be fine-tuned. An error is raised here when initializing the trainer.

    • Solution: Set base_model.is_quantized=False during model conversion.
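A minimal sketch of both fixes, assuming an HF-transformers Llama-style base model; the checkpoint name and attribute paths are illustrative, not the exact diff:

from transformers import AutoModelForCausalLM

# Fix 1: load the full base model on CPU so every module listed in
# _keep_in_fp32_modules actually exists, instead of loading a truncated
# model with num_hidden_layers=0.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder; any tested base model
    device_map="cpu",
)

# ...later, during convert() in offline mode, the decoder layers are dropped:
model.model._modules.pop("layers")

# Fix 2: clear the flag so the HF Trainer does not reject the model as a
# "purely quantized model" that cannot be fine-tuned.
model.is_quantized = False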

Other minor polishes/fixes:

Usage

Unchanged


Testing

Tested with dummy training on both online/offline and quantized/unquantized settings:

  • TinyLlama, online
  • TinyLlama, offline
  • GPTOSS-20b, online
  • GPTOSS-120b, offline

Saw training, evaluation, and checkpoint saving complete with no errors, and loss decreasing.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a training callback for periodic AR validation, configurable by step interval, with optional metric logging.
  • Improvements

    • Enhanced offline mode reliability by adapting to missing base layers and disabling incompatible checks.
    • Speculative generation now enforces SDPA attention during EAGLE forward passes for stability.
    • Model loading defaults to CPU device mapping for broader compatibility.
    • Reduced overhead by collecting auxiliary hidden states only when online.
  • Changes

    • Removed automatic Weights & Biases initialization; set it up manually if needed.
  • Refactor

    • Introduced a safe mechanism for temporary config overrides.


copy-pr-bot bot commented Sep 26, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 26, 2025

Walkthrough

Adds an AR validation TrainerCallback and wires it into speculative decoding examples; consolidates callback into eagle_utils and removes wandb usage from main. Introduces a temporary config context manager and updates the transformers plugin to handle offline mode (layer removal, quantization flags) and to enforce SDPA during EAGLE forward via the new context manager.

Changes

  • AR validation callback integration (examples/speculative_decoding/eagle_utils.py, examples/speculative_decoding/main.py): Introduces ARValidationCallback in eagle_utils.py (runs validate_ar every N steps; optional wandb; see the sketch below). Replaces the local callback in main.py with an import; removes wandb setup and the TrainerCallback export; keeps Trainer wiring. Note: a duplicate class definition is present in eagle_utils.py.
  • Speculative transformers plugin, offline + SDPA enforcement (modelopt/torch/speculative/plugins/transformers.py): Imports temporary_set_config_value. In offline mode, removes base model layers and forces is_quantized=False; skips EAGLE-3 aux hidden states. Wraps the EAGLE forward with a temporary SDPA attention selection and adapts the flow for the absence of base layers.
  • Config utility context manager (modelopt/torch/speculative/utils.py): Adds a temporary_set_config_value(config, field, value) context manager to temporarily set and restore config attributes; validates field existence; uses contextlib.
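A hedged sketch of the callback shape described above; the argument names mirror the review snippets later in this thread, so treat it as illustrative rather than the merged code:

from datasets import load_dataset
from transformers import TrainerCallback

from ar_validate import validate_ar  # examples/speculative_decoding/ar_validate.py


class ARValidationCallback(TrainerCallback):
    def __init__(self, ar_validate_steps: int = 1000):
        self.ar_validate_steps = ar_validate_steps

    def on_step_end(self, args, state, control, **kwargs):
        # Run AR validation only when enabled and at every N-th global step.
        if self.ar_validate_steps > 0 and state.global_step % self.ar_validate_steps == 0:
            ars = validate_ar(
                model=kwargs["model"],
                tokenizer=kwargs["processing_class"],
                ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                device=kwargs["model"].device,
            )
            if ars:  # guard against an empty result list
                # the real code logs via print_rank_0 and optionally wandb
                print(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
        return control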

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Trainer
  participant ARValidationCallback as ARValidationCallback
  participant validate_ar as validate_ar()
  participant Datasets as datasets.load_dataset
  participant W&B as wandb (optional)

  User->>Trainer: start training
  loop Every step
    Trainer->>ARValidationCallback: on_step_end(state.global_step)
    alt ar_validate_steps > 0 and step % N == 0
      ARValidationCallback->>Datasets: load_dataset(eval_prompts)
      ARValidationCallback->>validate_ar: model, tokenizer, dataset, device
      validate_ar-->>ARValidationCallback: AR score
      ARValidationCallback->>Trainer: log via stdout
      opt wandb available
        ARValidationCallback->>W&B: log AR score
      end
    else skip
    end
  end
sequenceDiagram
  autonumber
  participant Gen as SpeculativeGeneration
  participant Plugin as transformers plugin
  participant Model as Model
  participant Config as model.config
  participant Utils as temporary_set_config_value()

  Note over Plugin,Model: Initialization / Offline mode
  Plugin->>Model: if eagle_offline: pop "layers"
  Plugin->>Model: set is_quantized = False
  Plugin->>Plugin: gate EAGLE-3 aux states when offline

  Note over Gen,Model: EAGLE forward pass with SDPA
  Gen->>Plugin: eagle_forward(...)
  Plugin->>Utils: with temporary_set_config_value(Config, "attn_implementation", "sdpa")
  activate Utils
  Utils-->>Plugin: apply temp config
  Plugin->>Model: forward(...)
  Model-->>Plugin: outputs
  Utils-->>Plugin: restore original config

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A hop, a skip, through configs I glide,
SDPA winds at my whiskered side.
AR bells ring every thousand beats—
Offline burrows, trimmed of layers’ sheets.
With context spells that set, then free,
I nibble bugs—fast, happily.
(_/) ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 50.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check (✅ Passed): The title succinctly highlights the primary purpose of the pull request—the bug fix for the EAGLE3 quantized base model—and aligns with the two main objectives of correcting base-model loading and quantization flags without extraneous detail. It is concise, specific, and clearly signals the most important change to reviewers.

@h-guo18 h-guo18 self-assigned this Sep 26, 2025
@h-guo18 h-guo18 requested a review from yeyu-nvidia September 26, 2025 21:10

codecov bot commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 33.33333% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.84%. Comparing base (c9db0ce) to head (ae2aa93).
⚠️ Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/utils.py 33.33% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #383      +/-   ##
==========================================
- Coverage   73.86%   73.84%   -0.03%     
==========================================
  Files         171      171              
  Lines       17629    17638       +9     
==========================================
+ Hits        13021    13024       +3     
- Misses       4608     4614       +6     

☔ View full report in Codecov by Sentry.
@h-guo18 h-guo18 force-pushed the haoguo/fix-eagle3-quantized-basemodel branch from 38f8ada to ae2aa93 Compare September 29, 2025 23:01
@h-guo18
Contributor Author

h-guo18 commented Sep 29, 2025

/ok to test ae2aa93

@h-guo18 h-guo18 marked this pull request as ready for review September 29, 2025 23:02
@h-guo18 h-guo18 requested a review from a team as a code owner September 29, 2025 23:02
@h-guo18 h-guo18 requested a review from ChenhanYu September 29, 2025 23:02

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
modelopt/torch/speculative/plugins/transformers.py (1)

454-456: Consider the fragility of using _modules.pop().

Accessing _modules is using PyTorch's internal implementation details. While this fixes the offline training issue, consider whether there's a more stable API for removing submodules. The approach is acceptable given the constraints but may require maintenance if PyTorch internals change.
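For illustration, a tiny self-contained comparison of _modules.pop() with a public-API alternative; the module layout is a stand-in, not the plugin's actual structure:

import torch.nn as nn

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

m = Dummy()
m._modules.pop("layers")  # what the plugin does: touch the internal registry

m2 = Dummy()
delattr(m2, "layers")     # public alternative: nn.Module.__delattr__ unregisters submodules
assert "layers" not in m._modules and "layers" not in m2._modules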

examples/speculative_decoding/eagle_utils.py (1)

397-415: Add safety check for empty AR results.

If validate_ar returns an empty list (e.g., dataset has no samples), the calculation sum(ars) / len(ars) will raise a ZeroDivisionError.

Apply this diff to add a safety check:

             print_rank_0("Running AR validation...")
             ars = validate_ar(
                 model=kwargs["model"],
                 tokenizer=kwargs["processing_class"],
                 ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                 device=kwargs["model"].device,
             )
-            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-            if wandb:
-                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+            if ars:
+                avg_ar = sum(ars) / len(ars)
+                print_rank_0(f"Step {state.global_step} AR: {avg_ar:.4f}")
+                if wandb and wandb.run:
+                    wandb.log({"validate_ar": avg_ar}, step=state.global_step)
+            else:
+                print_rank_0(f"Step {state.global_step} AR validation returned no results")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c9db0ce and ae2aa93.

📒 Files selected for processing (4)
  • examples/speculative_decoding/eagle_utils.py (2 hunks)
  • examples/speculative_decoding/main.py (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (3 hunks)
  • modelopt/torch/speculative/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-29T17:40:37.261Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.261Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.

Applied to files:

  • modelopt/torch/speculative/plugins/transformers.py
📚 Learning: 2025-09-17T14:14:44.961Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.

Applied to files:

  • modelopt/torch/speculative/plugins/transformers.py
🧬 Code graph analysis (3)
modelopt/torch/speculative/plugins/transformers.py (3)
modelopt/torch/speculative/utils.py (1)
  • temporary_set_config_value (369-378)
modelopt/onnx/quantization/graph_utils.py (1)
  • layer_idx (781-783)
modelopt/torch/quantization/utils.py (1)
  • is_quantized (239-243)
examples/speculative_decoding/eagle_utils.py (2)
examples/speculative_decoding/ar_validate.py (1)
  • validate_ar (29-51)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
examples/speculative_decoding/main.py (2)
examples/speculative_decoding/eagle_utils.py (2)
  • ARValidationCallback (397-415)
  • make_eagle_supervised_data_module (248-325)
examples/speculative_decoding/medusa_utils.py (1)
  • make_medusa_supervised_data_module (196-219)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (8)
modelopt/torch/speculative/utils.py (1)

368-378: LGTM! Well-implemented context manager.

The temporary_set_config_value context manager is correctly implemented with proper validation, exception handling, and guaranteed restoration of the original value via the finally block.
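A hedged reconstruction of that shape, using the signature from the walkthrough; the exact error type and message are assumptions:

from contextlib import contextmanager

@contextmanager
def temporary_set_config_value(config, field, value):
    if not hasattr(config, field):
        raise ValueError(f"Config has no field {field!r}")
    original = getattr(config, field)
    setattr(config, field, value)
    try:
        yield
    finally:
        # Restore the original value even if the wrapped code raises.
        setattr(config, field, original)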

modelopt/torch/speculative/plugins/transformers.py (3)

918-928: Good use of the context manager for SDPA enforcement.

The temporary config switch to SDPA attention during pseudo_speculative_generate is well-implemented using the new context manager. The PR description mentions this avoids compile errors and improves performance compared to flex_attn.
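An illustrative call site, assuming the field name shown in the walkthrough diagram; the method body and arguments are placeholders:

# inside pseudo_speculative_generate (sketch)
with temporary_set_config_value(model.config, "attn_implementation", "sdpa"):
    eagle_outputs = model.eagle_module(hidden_states)  # placeholder call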


448-452: Aux hidden state collection guard is correct. The offline dataset in examples/speculative_decoding/eagle_utils.py provides aux_hidden_states, so gating hook registration with not eagle_offline prevents errors and requires no changes.


454-460: No unintended is_quantized attribute usage found. Verified that the only occurrence of self.is_quantized is the hack in this plugin; all quantization code uses the is_quantized(module) function, not this attribute. Forcing it to False solely bypasses the HuggingFace Trainer check as intended.

examples/speculative_decoding/main.py (4)

39-41: LGTM! Import consolidation improves modularity.

Moving ARValidationCallback to eagle_utils improves code organization by consolidating related utilities.


142-142: LGTM! CPU loading prevents initialization issues.

Loading the model to CPU first aligns with the fix for partial base-model loading errors mentioned in the PR objectives.


144-148: LGTM! Preserves layer count for offline training.

The logic correctly saves the original num_hidden_layers before conversion, which is necessary when layers are removed during offline training initialization.
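A sketch of that pattern; mtsp.convert stands in for the actual conversion entry point, and the config names are illustrative:

import modelopt.torch.speculative as mtsp

original_num_layers = model.config.num_hidden_layers  # record before layers are dropped
mtsp.convert(model, [("eagle", eagle_config)])        # offline conversion removes base layers
model.config.num_hidden_layers = original_num_layers  # keep the config consistent afterwards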


228-228: LGTM! Callback integration is correct.

The ARValidationCallback is properly instantiated with the configured validation interval.

Comment on lines +30 to +35
try:
    import wandb

    wandb.init()
except ImportError:
    wandb = None

⚠️ Potential issue | 🟠 Major

Premature wandb.init() call in import block.

Calling wandb.init() at module import time (line 33) is problematic because:

  1. It initializes wandb even if ARValidationCallback is never used
  2. It happens before the user can configure wandb settings
  3. It may conflict with other wandb initialization in the application

Apply this diff to defer initialization:

 try:
     import wandb
-
-    wandb.init()
 except ImportError:
     wandb = None

Then in the callback, check if wandb is initialized before logging:

             print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-            if wandb:
+            if wandb and wandb.run:
                 wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

# Before:
try:
    import wandb

    wandb.init()
except ImportError:
    wandb = None

# After: at the top of examples/speculative_decoding/eagle_utils.py, adjust the wandb import block:
try:
    import wandb
except ImportError:
    wandb = None

# … later in the file, inside ARValidationCallback (around the print_rank_0 call):
def on_evaluate(self, args, state, control, metrics=None, logs=None):
    # existing metric computation…
    print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
    if wandb and wandb.run:
        wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
    # rest of callback…

@h-guo18 h-guo18 merged commit c55bcf0 into main Sep 30, 2025
27 checks passed
@h-guo18 h-guo18 deleted the haoguo/fix-eagle3-quantized-basemodel branch September 30, 2025 01:30