fix: eagle3 quantized base model #383
**Walkthrough**

Adds an AR validation TrainerCallback and wires it into the speculative decoding examples; consolidates the callback into eagle_utils and removes wandb usage from main. Introduces a temporary config context manager and updates the transformers plugin to handle offline mode (layer removal, quantization flags) and to enforce SDPA during the EAGLE forward via the new context manager.
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Trainer
    participant ARValidationCallback as ARValidationCallback
    participant validate_ar as validate_ar()
    participant Datasets as datasets.load_dataset
    participant W&B as wandb (optional)
    User->>Trainer: start training
    loop Every step
        Trainer->>ARValidationCallback: on_step_end(state.global_step)
        alt ar_validate_steps > 0 and step % N == 0
            ARValidationCallback->>Datasets: load_dataset(eval_prompts)
            ARValidationCallback->>validate_ar: model, tokenizer, dataset, device
            validate_ar-->>ARValidationCallback: AR score
            ARValidationCallback->>Trainer: log via stdout
            opt wandb available
                ARValidationCallback->>W&B: log AR score
            end
        else skip
        end
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant Gen as SpeculativeGeneration
    participant Plugin as transformers plugin
    participant Model as Model
    participant Config as model.config
    participant Utils as temporary_set_config_value()
    Note over Plugin,Model: Initialization / Offline mode
    Plugin->>Model: if eagle_offline: pop "layers"
    Plugin->>Model: set is_quantized = False
    Plugin->>Plugin: gate EAGLE-3 aux states when offline
    Note over Gen,Model: EAGLE forward pass with SDPA
    Gen->>Plugin: eagle_forward(...)
    Plugin->>Utils: with temporary_set_config_value(Config, "attn_implementation", "sdpa")
    activate Utils
    Utils-->>Plugin: apply temp config
    Plugin->>Model: forward(...)
    Model-->>Plugin: outputs
    Utils-->>Plugin: restore original config
```
**Estimated code review effort:** 🎯 4 (Complex) | ⏱️ ~60 minutes
**Pre-merge checks and finishing touches**

- ❌ Failed checks (1 warning)
- ✅ Passed checks (2 passed)
**Codecov Report**

❌ Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #383      +/-   ##
==========================================
- Coverage   73.86%   73.84%   -0.03%
==========================================
  Files         171      171
  Lines       17629    17638       +9
==========================================
+ Hits        13021    13024       +3
- Misses       4608     4614       +6
```

☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <[email protected]>

Force-pushed 38f8ada to ae2aa93 (Compare)

/ok to test ae2aa93
Actionable comments posted: 1
🧹 Nitpick comments (2)
**modelopt/torch/speculative/plugins/transformers.py (1)**

`454-456`: **Consider the fragility of using `_modules.pop()`.**

Accessing `_modules` relies on PyTorch's internal implementation details. While this fixes the offline training issue, consider whether there's a more stable API for removing submodules. The approach is acceptable given the constraints but may require maintenance if PyTorch internals change.
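A possibly more stable alternative, sketched under the assumption that the offline path simply needs the decoder stack gone (the helper name and the `layers` attribute are illustrative, not the plugin's actual code):

```python
import torch.nn as nn


def drop_decoder_layers(base_model: nn.Module) -> None:
    """Hypothetical helper: swap the decoder stack for an empty ModuleList
    instead of popping entries from the private ``_modules`` dict."""
    if hasattr(base_model, "layers"):
        # nn.Module.__setattr__ updates the submodule registry through the
        # public API, so no private internals are touched.
        base_model.layers = nn.ModuleList()
```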
**examples/speculative_decoding/eagle_utils.py (1)**

`397-415`: **Add safety check for empty AR results.**

If `validate_ar` returns an empty list (e.g., the dataset has no samples), the calculation `sum(ars) / len(ars)` will raise a `ZeroDivisionError`.

Apply this diff to add a safety check:

```diff
             print_rank_0("Running AR validation...")
             ars = validate_ar(
                 model=kwargs["model"],
                 tokenizer=kwargs["processing_class"],
                 ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                 device=kwargs["model"].device,
             )
-            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-            if wandb:
-                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+            if ars:
+                avg_ar = sum(ars) / len(ars)
+                print_rank_0(f"Step {state.global_step} AR: {avg_ar:.4f}")
+                if wandb and wandb.run:
+                    wandb.log({"validate_ar": avg_ar}, step=state.global_step)
+            else:
+                print_rank_0(f"Step {state.global_step} AR validation returned no results")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)

- `examples/speculative_decoding/eagle_utils.py` (2 hunks)
- `examples/speculative_decoding/main.py` (2 hunks)
- `modelopt/torch/speculative/plugins/transformers.py` (3 hunks)
- `modelopt/torch/speculative/utils.py` (2 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-29T17:40:37.261Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#391
File: modelopt/torch/speculative/plugins/megatron_eagle.py:1049-1071
Timestamp: 2025-09-29T17:40:37.261Z
Learning: In EAGLE parallel draft implementation in modelopt/torch/speculative/plugins/megatron_eagle.py, overwriting eagle_hidden_states_*_pre_norm variables within the per-draft loop (rather than accumulating all draft pre-norm states) is the expected behavior for parallel draft functionality.
Applied to files:
modelopt/torch/speculative/plugins/transformers.py
📚 Learning: 2025-09-17T14:14:44.961Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.961Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
Applied to files:
modelopt/torch/speculative/plugins/transformers.py
🧬 Code graph analysis (3)
**modelopt/torch/speculative/plugins/transformers.py (3)**

- modelopt/torch/speculative/utils.py (1)
  - `temporary_set_config_value` (369-378)
- modelopt/onnx/quantization/graph_utils.py (1)
  - `layer_idx` (781-783)
- modelopt/torch/quantization/utils.py (1)
  - `is_quantized` (239-243)

**examples/speculative_decoding/eagle_utils.py (2)**

- examples/speculative_decoding/ar_validate.py (1)
  - `validate_ar` (29-51)
- modelopt/torch/utils/logging.py (1)
  - `print_rank_0` (92-95)

**examples/speculative_decoding/main.py (2)**

- examples/speculative_decoding/eagle_utils.py (2)
  - `ARValidationCallback` (397-415)
  - `make_eagle_supervised_data_module` (248-325)
- examples/speculative_decoding/medusa_utils.py (1)
  - `make_medusa_supervised_data_module` (196-219)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (8)
**modelopt/torch/speculative/utils.py (1)**

`368-378`: **LGTM! Well-implemented context manager.**

The `temporary_set_config_value` context manager is correctly implemented with proper validation, exception handling, and guaranteed restoration of the original value via the `finally` block.
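For readers who haven't opened the file, a minimal sketch of what such a context manager typically looks like (the exact validation and error message in utils.py may differ):

```python
from contextlib import contextmanager


@contextmanager
def temporary_set_config_value(config, key: str, value):
    """Temporarily override a config attribute, restoring it on exit."""
    if not hasattr(config, key):
        raise ValueError(f"Config has no attribute '{key}'")
    original = getattr(config, key)
    setattr(config, key, value)
    try:
        yield
    finally:
        # Restore the original value even if the wrapped block raises.
        setattr(config, key, original)
```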
**modelopt/torch/speculative/plugins/transformers.py (3)**

`918-928`: **Good use of the context manager for SDPA enforcement.**

The temporary config switch to SDPA attention during `pseudo_speculative_generate` is well implemented using the new context manager. The PR description mentions this avoids compile errors and improves performance compared to flex_attn.
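Illustrative usage of the pattern, with a plain namespace standing in for `model.config` (the attribute name follows the sequence diagram above; the plugin's real call site may differ):

```python
from types import SimpleNamespace

config = SimpleNamespace(attn_implementation="flex_attention")

with temporary_set_config_value(config, "attn_implementation", "sdpa"):
    # The EAGLE draft forward would run here with SDPA attention.
    assert config.attn_implementation == "sdpa"

# The original attention implementation is restored after the block.
assert config.attn_implementation == "flex_attention"
```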
`448-452`: **Aux hidden state collection guard is correct.**

The offline dataset in examples/speculative_decoding/eagle_utils.py provides aux_hidden_states, so gating hook registration with `not eagle_offline` prevents errors and requires no changes.
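A toy version of that gate; the attribute names and hook body are invented for illustration only:

```python
import torch.nn as nn


def register_aux_hidden_state_hooks(base_model: nn.Module, eagle_offline: bool, collected: list) -> None:
    """Capture aux hidden states from base-model layers, but only in online mode;
    offline datasets already ship precomputed aux_hidden_states (hypothetical sketch)."""
    if eagle_offline:
        return  # base-model layers were removed, so there is nothing to hook
    for layer in base_model.layers:
        layer.register_forward_hook(lambda _mod, _inp, out: collected.append(out))
```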
`454-460`: **No unintended `is_quantized` attribute usage found.**

Verified that the only occurrence of `self.is_quantized` is the hack in this plugin; all quantization code uses the `is_quantized(module)` function, not this attribute. Forcing it to `False` solely bypasses the HuggingFace Trainer check as intended.
**examples/speculative_decoding/main.py (4)**

`39-41`: **LGTM! Import consolidation improves modularity.**

Moving `ARValidationCallback` to `eagle_utils` improves code organization by consolidating related utilities.
`142-142`: **LGTM! CPU loading prevents initialization issues.**

Loading the model to CPU first aligns with the fix for partial base-model loading errors mentioned in the PR objectives.
`144-148`: **LGTM! Preserves layer count for offline training.**

The logic correctly saves the original `num_hidden_layers` before conversion, which is necessary when layers are removed during offline training initialization.
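A toy illustration of why the value must be captured before conversion; the conversion side effect is emulated directly on a config object, and the real names in main.py may differ:

```python
from transformers import LlamaConfig

config = LlamaConfig(num_hidden_layers=32)

# Save the true depth before conversion; offline conversion removes the decoder
# layers, after which the config no longer reflects the original base model.
original_num_hidden_layers = config.num_hidden_layers

config.num_hidden_layers = 0  # stand-in for the side effect of offline conversion
assert original_num_hidden_layers == 32  # still available for later use
```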
`228-228`: **LGTM! Callback integration is correct.**

The `ARValidationCallback` is properly instantiated with the configured validation interval.
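For orientation, the callback under review roughly follows this shape, reconstructed from the diff and sequence diagram above; the import paths are inferred from the code graph analysis, and details such as the default interval are assumptions:

```python
from ar_validate import validate_ar  # examples/speculative_decoding/ar_validate.py
from datasets import load_dataset
from modelopt.torch.utils.logging import print_rank_0
from transformers import TrainerCallback


class ARValidationCallback(TrainerCallback):
    """Periodically measure acceptance rate (AR) on MT-Bench prompts (sketch)."""

    def __init__(self, ar_validate_steps: int = 1000):
        self.ar_validate_steps = ar_validate_steps

    def on_step_end(self, args, state, control, **kwargs):
        if self.ar_validate_steps > 0 and state.global_step % self.ar_validate_steps == 0:
            print_rank_0("Running AR validation...")
            ars = validate_ar(
                model=kwargs["model"],
                tokenizer=kwargs["processing_class"],
                ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                device=kwargs["model"].device,
            )
            if ars:  # guard against empty results, as suggested above
                print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
        return control
```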
```python
try:
    import wandb

    wandb.init()
except ImportError:
    wandb = None
```
**Premature `wandb.init()` call in import block.**

Calling `wandb.init()` at module import time (line 33) is problematic because:

- It initializes wandb even if ARValidationCallback is never used
- It happens before the user can configure wandb settings
- It may conflict with other wandb initialization in the application

Apply this diff to defer initialization:

```diff
 try:
     import wandb
-
-    wandb.init()
 except ImportError:
     wandb = None
```

Then in the callback, check if wandb is initialized before logging:

```diff
             print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-            if wandb:
+            if wandb and wandb.run:
                 wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
```
📝 Committable suggestion
```python
# At the top of examples/speculative_decoding/eagle_utils.py, adjust the wandb import block:
try:
    import wandb
except ImportError:
    wandb = None

# … later in the file, inside ARValidationCallback (around the print_rank_0 call):
def on_evaluate(self, args, state, control, metrics=None, logs=None):
    # existing metric computation…
    print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
    if wandb and wandb.run:
        wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
    # rest of callback…
```
What does this PR do?
Type of change: Bug Fix
Overview:
Fixed 2 bugs when training with a quantized base model:
1. Some base model modules are not loaded but are present in `_keep_in_fp32_modules`, raising an error here when loading the model. Fix: instead of passing `num_hidden_layers=0` to `from_pretrained()` to load a partial base model, we load the entire base model to CPU first, then delete `base_model.layers` during `convert()`.
2. The base model is considered a "purely quantized model" by the HF trainer and therefore cannot be fine-tuned; an error is raised here when initializing the trainer. Fix: set `base_model.is_quantized = False` during model conversion (see the sketch below).
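A rough, hypothetical sketch of the two fixes combined (the path and conversion details are placeholders; the real changes live in main.py and the transformers plugin):

```python
from transformers import AutoModelForCausalLM

base_model_path = "path/to/quantized-base-model"  # placeholder

# Fix 1: load the full base model on CPU (from_pretrained's default device), so every
# module referenced by _keep_in_fp32_modules exists at load time; the EAGLE convert()
# step later deletes base_model.layers for offline training instead of asking
# from_pretrained() for a 0-layer model up front.
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="auto")

# Fix 2: clear the flag the HF Trainer uses to refuse fine-tuning of quantized models
# (done inside convert() in the actual plugin).
model.is_quantized = False
```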
Other minor polishes/fixes:
- Moved `ARValidationCallback` to utils for better clarity of `main.py`;
- Removed redundant `trainer._move_model_to_device()` in `main.py`, since it is called during trainer initialization and `trainer.train()`;
- Use `sdpa` attention during `pseudo_generate` instead of `flex_attn` to avoid compile errors and improve performance.

Usage

Unchanged.
Testing
Tested with dummy training runs in both online/offline and quantized/unquantized settings:
Training, evaluation, and checkpoint saving all completed without errors, and the loss decreased.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit

- New Features
- Improvements
- Changes
- Refactor