Avoid squeezing original weight tensors with leading dim 1 #294
Signed-off-by: Chenjie Luo <[email protected]>
```python
post_state_dict[prefix + new_suffix] = value
break

# Squeeze tensors with a leading dimension of 1
```
Can you specify what errors you get for Phi-4-MM?
We needed this to stick with our checkpoint format, and I recall that if we don't squeeze, it breaks some deployment flows with TRT-LLM.
It has worked for all our supported models so far.
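For context, a minimal sketch of the reported failure mode (NumPy stands in for torch, and the shapes are hypothetical): unconditionally squeezing a 3-D tensor whose leading dimension is 1 changes the exported shape, so it no longer matches the model definition on reload.

```python
import numpy as np

# Hypothetical Phi-4-MM-style weight whose real shape has a leading dim of 1.
original_weight = np.zeros((1, 4, 8))

# The old export path squeezed any 3-D tensor with shape[0] == 1 ...
exported = original_weight.squeeze(0)  # shape becomes (4, 8)

# ... so the exported checkpoint no longer matches the original weight shape.
mismatch = exported.shape != original_weight.shape
```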
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #294   +/-  ##
=======================================
  Coverage   73.93%   73.93%
=======================================
  Files         172      172
  Lines       17408    17408
=======================================
  Hits        12870    12870
  Misses       4538     4538
```
Signed-off-by: Chenjie Luo <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/export/quant_utils.py (1)
Lines 872-881: Narrow the key match and avoid mutating during iteration. The "scale" substring can overmatch (e.g., keys containing "rescaled"). It is safer to target leaf names ending with "scale" (e.g., k_scale, v_scale, pre_quant_scale) and iterate over a snapshot for clarity.

Apply:

```diff
-    # Squeeze scales with a leading dimension of 1
-    for key, value in post_state_dict.items():
-        if (
-            "scale" in key
-            and isinstance(value, torch.Tensor)
-            and value.dim() == 3
-            and value.shape[0] == 1
-        ):
-            post_state_dict[key] = value.squeeze(0)
+    # Only squeeze scale tensors (not weights) when the leading expert dim is 1.
+    for key, value in list(post_state_dict.items()):
+        if (
+            isinstance(value, torch.Tensor)
+            and value.dim() == 3
+            and value.shape[0] == 1
+            and key.rsplit(".", 1)[-1].endswith("scale")
+        ):
+            post_state_dict[key] = value.squeeze(0)
```
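To see why the leaf-suffix match is the safer predicate, here is a standalone illustration in plain Python (the keys are hypothetical examples): a naive substring check picks up any key containing "scale", while matching on the leaf name's suffix only accepts genuine scale tensors.

```python
def is_scale_key(key: str) -> bool:
    # Match only leaf names ending with "scale" (k_scale, v_scale, pre_quant_scale, ...).
    return key.rsplit(".", 1)[-1].endswith("scale")

keys = [
    "model.layers.0.attn.k_scale",         # genuine scale: should match
    "model.layers.0.mlp.pre_quant_scale",  # genuine scale: should match
    "model.layers.0.rescaled_weight",      # contains "scale" but is a weight
]
suffix_matches = [k for k in keys if is_scale_key(k)]
substring_matches = [k for k in keys if "scale" in k]  # overmatches the weight
```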
📒 Files selected for processing (1): modelopt/torch/export/quant_utils.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4):
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/export/quant_utils.py (2)
Lines 872-881: Good fix: stop squeezing weights; only squeeze scale tensors with leading dim 1. This aligns with the Phi-4-MM report and should avoid shape mismatches on weights while keeping exported scale shapes sane.

Lines 872-881: Add a regression test for the squeeze behavior of scales vs. weights: ensure that weight tensors (e.g., shape (1, 4, 8)) retain their leading dimension while scale tensors (e.g., shape (1, 4, 1)) are squeezed to (4, 1).
Discussed offline - We should whitelist the scaling factor tensors for squeezing rather than remove it.
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Chenjie Luo <[email protected]>
What does this PR do?
Bug fix.
Overview: Phi4MM has weights with a leading dim of 1. Squeezing them yields a shape mismatch with the original weight.
Testing
Run:

```shell
scripts/huggingface_example.sh --model /models/microsoft/Phi-4-multimodal-instruct/ --quant fp8 --export_fmt hf --trust_remote_code
```

and check the weight shapes.