Fix: supporting gpt-oss HF eagle #398
base: main
Conversation
Signed-off-by: h-guo18 <[email protected]>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Walkthrough

Adds an eagle-only speculative-decoding export path with new spec_opt_only gating and dedicated exporters for the state_dict and config; export_hf_checkpoint now early-returns for speculative-only models, writing model.safetensors and config.json directly.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as Caller
    participant E as export_hf_checkpoint
    participant P as hf_spec_export (plugin)
    participant S as safetensors.save_file
    participant W as _export_hf_checkpoint
    U->>E: export_hf_checkpoint(model, out_dir)
    E->>P: spec_opt_only(model)
    alt Speculative-only (eagle)
        E->>P: export_spec_ckpt_state_dict(model)
        P-->>E: state_dict
        E->>S: save_file(state_dict, "out_dir/model.safetensors")
        E->>P: export_spec_ckpt_config(model)
        P-->>E: config_json
        E->>E: write "out_dir/config.json"
        E-->>U: return (early exit)
    else Non-speculative or mixed
        E->>W: _export_hf_checkpoint(model, out_dir)
        W-->>E: standard artifacts
        E-->>U: return
    end
    note over E,P: Prior rename/prune and config-adjust hooks removed from normal path
```
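In code form, the eagle-only branch of the diagram corresponds roughly to the sketch below. The helper names (spec_opt_only, export_spec_ckpt_state_dict, export_spec_ckpt_config) come from the plugin listed in the code-graph section of this review; the wrapper function export_eagle_only and its directory handling are illustrative, not the PR's actual implementation.

```python
import json
from pathlib import Path

from safetensors.torch import save_file

from modelopt.torch.export.plugins.hf_spec_export import (
    export_spec_ckpt_config,
    export_spec_ckpt_state_dict,
    spec_opt_only,
)


def export_eagle_only(model, export_dir: str) -> bool:
    """Write model.safetensors + config.json for a purely speculative (eagle) model.

    Returns True if the eagle-only path was taken, False if the caller should fall
    back to the standard _export_hf_checkpoint path.
    """
    out = Path(export_dir)
    out.mkdir(parents=True, exist_ok=True)
    if not spec_opt_only(model):
        return False
    # Draft-head weights only; the base model is not re-exported on this path.
    save_file(export_spec_ckpt_state_dict(model), str(out / "model.safetensors"))
    with open(out / "config.json", "w") as f:
        json.dump(export_spec_ckpt_config(model), f, indent=4)
    return True
```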
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Signed-off-by: h-guo18 <[email protected]>
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #398   +/-   ##
=======================================
  Coverage   73.79%   73.79%
=======================================
  Files         171      171
  Lines       17591    17591
=======================================
  Hits        12982    12982
  Misses       4609     4609
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (4)
modelopt/torch/export/unified_export_hf.py (1)

512-518: Consider honoring the save_modelopt_state parameter and adding error handling.

The early-exit path correctly prevents errors with offline training checkpoints, but consider these improvements:
- The save_modelopt_state parameter (line 499) is unused in this path. If users request modelopt state preservation, should it be saved separately?
- File write operations lack error handling, unlike the try-except block in the standard export path (lines 520-550).
- Consider using Path operations for consistency: export_dir / "model.safetensors" instead of f-strings.

Optional refactor to use Path operations:
```diff
     if spec_opt_only(model):
-        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-        with open(f"{export_dir}/config.json", "w") as file:
+        save_file(export_spec_ckpt_state_dict(model), export_dir / "model.safetensors")
+        with open(export_dir / "config.json", "w") as file:
             json.dump(export_spec_ckpt_config(model), file, indent=4)
         return
```
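On the first point (the unused save_modelopt_state flag), a minimal sketch of one way to honor it on the early-exit path; it assumes modelopt.torch.opt.modelopt_state is the appropriate accessor and that a separate .pth file is an acceptable place to store the state, both of which are design calls for the PR author rather than anything this PR implements.

```python
import torch

import modelopt.torch.opt as mto


def maybe_save_modelopt_state(model, export_dir, save_modelopt_state: bool) -> None:
    """Persist the ModelOpt state alongside the eagle-only export if requested."""
    if save_modelopt_state:
        # mto.modelopt_state(model) returns a serializable snapshot of applied modes/configs.
        torch.save(mto.modelopt_state(model), f"{export_dir}/modelopt_state.pth")
```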
modelopt/torch/export/plugins/hf_spec_export.py (3)

77-79: Track the TODO for cleaner lm_head.weight handling.

The temporary fix for handling missing eagle_lm_head.weight works but should be addressed. The fallback to model.state_dict()["lm_head.weight"] could fail if the key doesn't exist in the base model either.

Do you want me to open a new issue to track this technical debt?
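A possible shape for the cleaner fallback the TODO asks for, as a sketch only; the key names follow the comment above, and the helper resolve_eagle_lm_head and its call site are hypothetical, not code from the PR.

```python
import torch


def resolve_eagle_lm_head(
    eagle_state: dict[str, torch.Tensor], base_state: dict[str, torch.Tensor]
) -> torch.Tensor:
    """Return the draft lm_head weight, with an explicit error instead of a KeyError deep inside export."""
    if "eagle_lm_head.weight" in eagle_state:
        return eagle_state["eagle_lm_head.weight"]
    if "lm_head.weight" in base_state:
        return base_state["lm_head.weight"]
    raise KeyError(
        "Neither 'eagle_lm_head.weight' nor 'lm_head.weight' was found; "
        "cannot export the eagle draft head."
    )
```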
141-141: Fix typo in comment.

Minor typo: "load fron eagle config" should be "load from eagle config".

```diff
-    # First, we try to load fron eagle config.
+    # First, we try to load from eagle config.
```
62-62: Consider a more descriptive assertion message.

The assertion message "Not purely eagle model." could be more helpful for debugging. Consider providing information about what optimization modes were found.

```diff
-    assert spec_opt_only(model), "Not purely eagle model."
+    opt_modes = getattr(model, "_modelopt_state", None)
+    assert spec_opt_only(model), (
+        f"Expected purely eagle model but found optimization modes: {opt_modes}. "
+        "This export path only supports models with a single 'eagle' optimization."
+    )
```
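For context on what this gate checks, a hedged reading of spec_opt_only (hf_spec_export.py, lines 51-56) inferred from the suggested diff above; the exact structure of _modelopt_state is an assumption, so treat this as a sketch rather than the repository's code.

```python
def spec_opt_only(model) -> bool:
    """True when the model carries exactly one ModelOpt mode and it is 'eagle' (assumed state layout)."""
    opt_modes = getattr(model, "_modelopt_state", None)
    return (
        isinstance(opt_modes, (list, tuple))
        and len(opt_modes) == 1
        and opt_modes[0][0] == "eagle"
    )
```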
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (4)
- examples/speculative_decoding/README.md (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (2 hunks)
- modelopt/torch/export/unified_export_hf.py (3 hunks)
- modelopt/torch/speculative/eagle/default_config.py (1 hunks)
🧰 Additional context used

🧬 Code graph analysis (1)
modelopt/torch/export/unified_export_hf.py (1)
- modelopt/torch/export/plugins/hf_spec_export.py (3)
  - export_spec_ckpt_config (84-148)
  - export_spec_ckpt_state_dict (59-81)
  - spec_opt_only (51-56)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (1)

modelopt/torch/speculative/eagle/default_config.py (1)

50-50: Verify head_dim in eagle default_config

In modelopt/torch/speculative/eagle/default_config.py (line 50), head_dim is set to 64. Confirm that this matches hidden_size / num_attention_heads in the same file (or document why it intentionally differs) to avoid silent mis-inference.
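One quick way to act on this is a consistency check along these lines; the dict stands in for whatever structure default_config.py actually uses (its real variable name and values are not shown in this review), so treat the numbers and keys as illustrative.

```python
# Flag when head_dim deliberately diverges from the usual hidden_size // num_attention_heads
# inference; GPT-OSS-style models do diverge, which is the point of this PR.
cfg = {"hidden_size": 2880, "num_attention_heads": 64, "head_dim": 64}  # stand-in values

inferred = cfg["hidden_size"] // cfg["num_attention_heads"]
if cfg["head_dim"] != inferred:
    print(
        f"head_dim={cfg['head_dim']} intentionally differs from "
        f"hidden_size//num_attention_heads={inferred}; document this in default_config.py"
    )
```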
Signed-off-by: h-guo18 <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview:
This PR contains two minor fixes to support gpt-oss eagle training:
- Add head_dim to the default eagle config to prevent Llama from inferring the head dim as hidden_size/num_heads, which yields the wrong head dim for models like GPT-OSS where hidden_size != num_heads * head_dim (illustrated in the sketch below).
- Early-return before _export_hf_checkpoint for eagle-only models, since that path triggers an error for offline training checkpoints.
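To illustrate the first fix, a small sketch of the inference behavior being avoided; it assumes a transformers release where LlamaConfig accepts an explicit head_dim, and the shapes are GPT-OSS-like but illustrative only.

```python
from transformers import LlamaConfig

# Without an explicit head_dim, Llama code falls back to hidden_size // num_attention_heads.
cfg = LlamaConfig(hidden_size=2880, num_attention_heads=64)
inferred = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
print(inferred)  # 45 -- wrong for a model whose real head_dim is 64

# Setting head_dim explicitly in the default eagle config avoids the bad inference.
cfg = LlamaConfig(hidden_size=2880, num_attention_heads=64, head_dim=64)
print(cfg.head_dim)  # 64
```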
Usage
Not changed.
Testing
Tested with gpt-oss-120b: offline training, export, and evaluation of the exported checkpoint on spec-bench.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
- New Features
- Documentation
- Refactor
- Config