TensorQuantizer: remove amax validate forward to reduce cpu-gpu sync; new SVDQuantTensorQuantizer, minor code clean up #357
Conversation
Walkthrough

Adds rank-aware logging and warnings, per-module AWQLite/AWQClip gating with setup/cleanup hooks and module-name propagation into postprocess; validates TensorQuantizer attributes after calibration using the new validator; removes public svdquant LoRA properties from TensorQuantizer and introduces SVD LoRA support via SVDQuantTensorQuantizer and SVDQuantLinear.
Sequence Diagram(s)

sequenceDiagram
autonumber
actor User
participant Cal as Calibrator
participant AW as AWQFlow (awq_lite / awq_clip)
participant M as Module(s)
participant L as Logger (rank-0)
User->>Cal: calibrate(model, dataloader)
Cal->>AW: check is_enabled
alt disabled or invalid config
AW-->>L: warn and skip AWQ flow
else
AW->>AW: setup() -- patch module forwards, set per-module flags
Cal->>M: run forward_loop()
AW->>M: postprocess(module, name) for each module
AW->>AW: cleanup() -- restore forwards/attributes
AW-->>L: info AWQ flow complete
end
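The AWQ-lite flow above follows a setup, cache/search, postprocess, cleanup lifecycle. The snippet below is a self-contained sketch of that pattern only; the AWQLiteState container and helper names are illustrative assumptions, not the actual ModelOpt internals.

```python
# Sketch of a per-module setup -> forward_loop -> postprocess -> cleanup lifecycle.
import types
import torch
import torch.nn as nn

class AWQLiteState:
    def __init__(self):
        self.is_enabled = True
        self.act_scale = None
        self.num_cache_steps = 0

def setup(module: nn.Linear):
    module.awq_lite = AWQLiteState()
    module._original_forward = module.forward

    def patched_forward(self, x):
        # Cache a per-channel activation scale while data flows through.
        scale = x.detach().abs().amax(dim=tuple(range(x.dim() - 1)))
        st = self.awq_lite
        st.act_scale = scale if st.act_scale is None else torch.maximum(st.act_scale, scale)
        st.num_cache_steps += 1
        return self._original_forward(x)

    module.forward = types.MethodType(patched_forward, module)

def cleanup(module: nn.Linear):
    module.forward = module._original_forward
    del module._original_forward

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
for m in model:
    setup(m)
for _ in range(3):                      # stand-in for the user's forward_loop
    model(torch.randn(2, 8))
for name, m in model.named_children():  # postprocess: consume the cached stats
    if m.awq_lite.num_cache_steps == 0:
        print(f"{name}: no data seen, skipping")
for m in model:
    cleanup(m)
```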
sequenceDiagram
autonumber
participant In as Input
participant S as SVDQuantLinear
participant Q as SVDQuantTensorQuantizer
participant L as LinearOp
participant Out as Output
In->>S: forward(x)
S->>Q: check _pre_quant_scale and svdquant_lora_a/b
alt LoRA present
S->>S: x_scaled = _apply_pre_quant_scale(x)
S->>S: R = _compute_lora_residual(x_scaled)
S->>L: main linear forward (pre-quant temporarily disabled)
L-->>S: y_main
S->>Out: y = y_main + R
else
S->>L: main linear forward
L-->>Out: y
end
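A minimal sketch of the LoRA path in this diagram, using a toy module with illustrative attribute names (the real SVDQuantLinear/SVDQuantTensorQuantizer wiring differs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySVDQuantLinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.pre_quant_scale = nn.Parameter(torch.ones(in_features), requires_grad=False)
        self.lora_a = nn.Parameter(torch.randn(rank, in_features) * 0.01, requires_grad=False)
        self.lora_b = nn.Parameter(torch.randn(out_features, rank) * 0.01, requires_grad=False)

    def forward(self, x):
        x_scaled = x * self.pre_quant_scale                               # apply pre-quant scale once
        residual = F.linear(F.linear(x_scaled, self.lora_a), self.lora_b)  # low-rank residual R
        y_main = self.linear(x_scaled)                                     # main linear path
        return y_main + residual                                           # y = y_main + R

x = torch.randn(2, 16)
print(ToySVDQuantLinear(16, 8)(x).shape)  # torch.Size([2, 8])
```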
sequenceDiagram
autonumber
participant Cal as calibrate()
participant Mods as Modules
participant TQ as TensorQuantizer
participant L as Logger (rank-0)
Cal->>Mods: run forward_loop()
Cal->>Mods: iterate modules
loop each module with TensorQuantizer
Cal->>TQ: validate_attr(_amax, "amax", warn_error=True, name=mod_name)
Cal->>TQ: validate_attr(_pre_quant_scale, "pre_quant_scale", warn_error=True, name=mod_name)
Cal->>TQ: validate_attr(_bias_value, "bias", warn_error=True, name=mod_name)
end
Cal-->>L: info validation complete
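A hedged sketch of this post-calibration validation pass; the ToyQuantizer class and its validate_attr signature are stand-ins for the real TensorQuantizer API:

```python
import warnings
import torch
import torch.nn as nn

class ToyQuantizer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("_amax", torch.tensor([1.0, float("inf")]))  # deliberately invalid
        self._pre_quant_scale = None

    def validate_attr(self, attr_name="_amax", warn_error=False, name=""):
        value = getattr(self, attr_name, None)
        if value is None:
            return True
        ok = bool(torch.isfinite(value).all() and (value >= 0).all())
        if not ok and warn_error:
            warnings.warn(f"{name}: {attr_name} contains invalid values")
        return ok

model = nn.ModuleDict({"layer1": ToyQuantizer()})
for name, module in model.named_modules():
    if isinstance(module, ToyQuantizer):
        for attr in ("_amax", "_pre_quant_scale"):
            module.validate_attr(attr_name=attr, warn_error=True, name=name)
```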
Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)
518-531
: Bug: undefined name ‘module’ inside update_best_params

Should reference self, not module. This crashes postprocess.

-def update_best_params(self):
-    if not module.awq_lite.is_enabled:
-        return
-    self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
-    self.awq_lite.best_scale = get_scale(
-        self.awq_lite.act_scale,
-        self.awq_lite.weight_scale,
-        self.awq_lite.best_alpha,
-        (
-            self.parallel_state.tensor_parallel_group
-            if is_quantized_column_parallel_linear(self)
-            else None
-        ),
-    )
+def update_best_params(self):
+    if not self.awq_lite.is_enabled:
+        return
+    self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
+    self.awq_lite.best_scale = get_scale(
+        self.awq_lite.act_scale,
+        self.awq_lite.weight_scale,
+        self.awq_lite.best_alpha,
+        (
+            self.parallel_state.tensor_parallel_group
+            if is_quantized_column_parallel_linear(self)
+            else None
+        ),
+    )
🧹 Nitpick comments (3)
modelopt/torch/quantization/model_quant.py (1)
144-145
: Docstring typo: s/quaTruent_cfg/quant_cfg

Minor copy fix.

- performs calibration as specified by ``quaTruent_cfg``.
+ performs calibration as specified by ``quant_cfg``.

modelopt/torch/quantization/nn/modules/quant_linear.py (1)
120-123
: Guard _setup against missing weight_quantizer

If _setup runs before weight_quantizer exists, this will raise. Guard and fall back to super.

 def _setup(self):
     """Overrides and bypass the _setup function."""
-    self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+    if not hasattr(self, "weight_quantizer"):
+        super()._setup()
+    self.weight_quantizer.__class__ = SVDQuantTensorQuantizer

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
507-524
: Make validate_attr robust to non-tensor numerics and NaN

Coerce numerics to tensors; prefer isfinite to catch NaNs. Non-numeric types (e.g., dict) should early-return.

 def validate_attr(
     self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
 ):
     """Validate attribute."""
     attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
     if attr_value is None:
         return True
-    is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value))
+    # Coerce numeric to tensor; skip validation for non-numeric types (e.g., config dicts)
+    if not isinstance(attr_value, torch.Tensor):
+        try:
+            attr_value = torch.as_tensor(attr_value)
+        except Exception:
+            return True
+    is_valid = torch.all(attr_value >= 0) and torch.all(torch.isfinite(attr_value))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py (13 hunks)
- modelopt/torch/quantization/model_quant.py (2 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/model_quant.py (2)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): validate_attr (507-523)
modelopt/torch/quantization/model_calib.py (5)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (8): axis (283-285, 288-290), is_enabled (395-397), forward (843-943), disable (399-404), validate_attr (507-523), amax (236-241, 244-255)
- modelopt/torch/utils/network.py (2): bind_forward_method (634-653), unpatch_forward_method (656-660)
- modelopt/torch/trace/symbols.py (2): disable (108-135), named_modules (444-447)
- modelopt/torch/quantization/utils.py (1): enable_weight_access_and_writeback (424-443)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): TensorQuantizer (65-1108)
- modelopt/torch/quantization/nn/modules/quant_module.py (3): QuantLinearConvBase (129-169), _setup (118-126), _setup (163-169)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (23)
modelopt/torch/quantization/model_calib.py (15)
28-28: LGTM: rank-aware logging import. print_rank_0 improves multi-rank UX.
372-379: LGTM: rank-aware warnings in smoothquant. Non-blocking modules are cleanly skipped with informative, rank-0 messages. Also applies to: 389-389.
440-443: LGTM: awq_lite gracefully disabled without forward_loop. Better UX than hard error.
459-472: LGTM: AWQ-lite setup and gating. Forward patching and per-channel axis gating are correct. Input quantizer temporarily disabled is fine.
473-478: LGTM: Cleanup restores forward and removes calib flag. Prevents lingering patched state.
542-557: LGTM: Cache-mode statistics collection. Correct handling for DTensor, act scale accumulation, and optional input quantizer stats.
572-587: LGTM: Attribute validation gate for awq_lite. Good use of validate_attr on pre-quant scale and amax to fail fast.
584-587: LGTM: Disable AWQ-lite on invalid attributes. Safe fallback.
604-605: LGTM: Rank-aware cache notice.
626-627: LGTM: Rank-aware search notice.
630-650: Postprocess depends on fix in update_best_params. Flow is correct; ensure the self/module bug is fixed to populate best_scale.
651-664: LGTM: Guard for no-cache or no-search passes with actionable warnings. Clear guidance to users to adjust forward_loop.
665-665: LGTM: Cleanup after AWQ-lite. Avoids leaving patched methods or state.
852-852: LGTM: Rank-aware log in awq_clip.
913-913: LGTM: Rank-aware SVD log.
modelopt/torch/quantization/nn/modules/quant_linear.py (3)
65-116: LGTM: SVDQuantTensorQuantizer properties align with TensorQuantizer patterns. Setters follow buffer semantics and shape stability. Consider exporting if intended as public: if this class is meant to be public, add it to __all__.
147-161: LGTM: Forward integrates LoRA residual with pre-quant scaling. Correctly disables pre_quant_scale for the main path to avoid double scaling.
163-187: LGTM: Fold weight applies LoRA residual and cleans up buffers. Avoids double-applying the residual post-fold.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5)
31-35: LGTM: Centralized ONNX GLOBALS import. Avoids runtime import in forward; supports multiple torch versions.
257-263: LGTM: reset_amax also resets bias. Keeps calibration state consistent.
629-631: LGTM: Explicit ValueError for missing block size. Clearer error than assert.
773-777: LGTM: Raise on input shape change during block-quant. Prevents silent misquantization.
828-829: LGTM: Validate exported amax. Early catch for invalid scales.
Force-pushed from f8bbfe5 to 555ceea.
Actionable comments posted: 0
🧹 Nitpick comments (3)
modelopt/torch/quantization/model_calib.py (2)
372-372
: Minor: Consider consistent formatting for warning messages.

For consistency, consider using print_rank_0 for warnings as well since you've migrated other print statements.

- warnings.warn(f"{name} is not calibrated, skip smoothing")
+ print_rank_0(f"Warning: {name} is not calibrated, skip smoothing")
375-375
: Minor: Consider consistent warning format.

Similar to Line 372, consider using print_rank_0 for consistency with the distributed printing pattern.

Also applies to: 378-378
modelopt/torch/quantization/nn/modules/quant_linear.py (1)
120-122
: Avoid mutating weight_quantizer.__class__; instantiate/convert instead.

Setting self.weight_quantizer.__class__ = SVDQuantTensorQuantizer bypasses init/serialization/type-checking and can break debugging/typing; replace it with a conversion constructor (e.g. SVDQuantTensorQuantizer.from_tensor_quantizer(self.weight_quantizer)) or perform the transformation in the parent setup. Occurrences: modelopt/torch/quantization/nn/modules/quant_linear.py:122 (this PR); similar patterns found in modelopt/torch/trace/modules/concat.py, modelopt/torch/nas/traced_hp.py, modelopt/torch/opt/dynamic.py; consolidate or document if intentional.
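For illustration, a conversion constructor along the lines the comment proposes could look like the sketch below; from_tensor_quantizer is the reviewer's suggested (hypothetical) name, and both classes here are stand-ins rather than the real ModelOpt types:

```python
import torch.nn as nn

class BaseQuantizer(nn.Module):          # stand-in for TensorQuantizer
    pass

class SVDQuantQuantizer(BaseQuantizer):  # stand-in for SVDQuantTensorQuantizer
    @classmethod
    def from_tensor_quantizer(cls, quantizer: BaseQuantizer) -> "SVDQuantQuantizer":
        new = cls.__new__(cls)
        # Reuse the existing state (parameters, buffers, flags) instead of
        # mutating __class__ in place on the original instance.
        new.__dict__.update(quantizer.__dict__)
        return new

# Usage sketch:
wq = BaseQuantizer()
wq_svd = SVDQuantQuantizer.from_tensor_quantizer(wq)
print(type(wq_svd).__name__)  # SVDQuantQuantizer
```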
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py (13 hunks)
- modelopt/torch/quantization/model_quant.py (2 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/model_quant.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): TensorQuantizer (65-1107)
- modelopt/torch/quantization/nn/modules/quant_module.py (3): QuantLinearConvBase (129-169), _setup (118-126), _setup (163-169)
modelopt/torch/quantization/model_calib.py (4)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6): is_enabled (395-397), forward (843-942), disable (399-404), validate_attr (507-523), amax (236-241, 244-255)
- modelopt/torch/utils/network.py (2): bind_forward_method (634-653), unpatch_forward_method (656-660)
- modelopt/torch/quantization/utils.py (1): enable_weight_access_and_writeback (424-443)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (24)
modelopt/torch/quantization/model_calib.py (15)
28-28: LGTM! Improved distributed printing. The switch to rank-aware printing using print_rank_0 is a good improvement to avoid duplicate output in distributed training environments.
31-34: LGTM! Proper ONNX GLOBALS compatibility handling. The conditional import properly handles different PyTorch versions where the ONNX GLOBALS location has changed, ensuring compatibility across PyTorch versions. The fallback import path is appropriate.
389-389: LGTM! Improved distributed output. Converting to rank-aware printing is appropriate for distributed environments.
440-442: Improved error handling for missing forward_loop. Good defensive programming: gracefully handling a missing forward_loop with a warning rather than letting it fail later. This is better than the previous behavior.
459-471: Enhanced AWQLite lifecycle management. The addition of setup(), cleanup(), and is_enabled state management improves the robustness of AWQLite. The axis validation on lines 468-470 is a good safeguard against unsupported configurations.
473-477: Proper cleanup implementation. Good cleanup logic to remove temporary attributes and unpatch forward methods, preventing potential memory leaks or interference.
518-519: Enhanced validation and error handling. The addition of is_enabled checks and early returns for disabled modules prevents unnecessary processing. This is a good defensive programming pattern. Also applies to: 542-542, 584-586.
571-587: Critical improvement: validation of quantization parameters. This is a significant improvement that validates _pre_quant_scale and _amax attributes before proceeding with AWQ processing. The validation prevents processing with invalid quantization parameters that could lead to incorrect results.
604-604: LGTM! Enhanced progress reporting. The switch to rank-aware logging for progress updates is appropriate for distributed training. Also applies to: 626-626.
630-630: Updated function signature for enhanced logging. The addition of the name parameter to postprocess enables better context-aware logging and error reporting.
636-636: Improved conditional processing and warning messages. The conditional check for is_enabled before applying pre-quantization scaling (line 645) and the informative warning messages provide better user feedback about disabled modules. Also applies to: 645-650.
653-662: Enhanced validation for AWQLite execution. The additional validation checks for num_cache_steps and num_search_steps help identify configuration issues early, providing clear feedback about invalid forward_loop functions.
665-665: Proper cleanup call. The explicit cleanup() call ensures proper resource management and prevents potential issues.
852-852: LGTM! Consistent rank-aware logging. Converting AWQClip logging to use rank-aware printing for consistency with other distributed operations.
913-913: LGTM! Enhanced SVDQuant logging. Adding rank-aware printing with the module name provides better context for SVDQuant operations in distributed environments.
modelopt/torch/quantization/nn/modules/quant_linear.py (4)
65-115
: Well-implemented SVDQuantTensorQuantizer class. The new SVDQuantTensorQuantizer class properly extends TensorQuantizer with LoRA support. The property implementation follows the same pattern as the parent class with proper buffer registration, shape validation, and device handling.
129-146: LGTM! Clean helper method implementations. Both _apply_pre_quant_scale and _compute_lora_residual are well-implemented with proper conditional checks and clear logic. The use of _not_sequential_quantizers() provides consistent validation.
147-161
: Enhanced forward method with LoRA support. The forward method cleanly handles the SVDQuant LoRA case by:
- Detecting LoRA presence
- Applying pre-quantization scaling
- Computing LoRA residuals
- Disabling pre-quant scaling for the main forward pass
- Adding residuals to the final output
The logic is sound and maintains proper separation of concerns.
163-186
: Proper weight folding with cleanup. The fold_weight implementation correctly:
- Calls the parent fold_weight method
- Adds the LoRA residual to the weight when folding
- Cleans up the LoRA attributes after folding
This ensures the folded weight contains the complete transformation.
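A small self-contained sketch of that fold step, with illustrative names (the actual fold_weight implementation lives in SVDQuantLinear and also clears the LoRA buffers afterwards):

```python
import torch

def fold_lora_into_weight(weight: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor):
    with torch.no_grad():
        weight.add_(lora_b @ lora_a)   # weight := weight + B @ A
    return weight

w = torch.zeros(8, 16)
a = torch.randn(4, 16)   # (rank, in_features)
b = torch.randn(8, 4)    # (out_features, rank)
folded = fold_lora_into_weight(w, a, b)
assert torch.allclose(folded, b @ a)
```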
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5)
31-34
: LGTM! Improved ONNX compatibility. The conditional import properly handles different PyTorch versions and ensures GLOBALS is available at the module level, which should resolve TorchDynamo compilation issues mentioned in the PR objectives.
507-523: Excellent addition: centralized attribute validation. The new validate_attr method provides a clean, centralized way to validate tensor attributes with configurable error handling (warnings vs exceptions). This is much better than scattered validation logic.
262-262: Enhanced reset functionality. Adding self.reset_bias() to reset_amax() ensures complete quantizer state reset, which is more thorough than just resetting amax.
629-630: Improved error handling with ValueError. Converting runtime assertions to proper ValueError exceptions improves error clarity and follows Python best practices for input validation. Also applies to: 773-777.
828-828: LGTM! Using centralized validation. Replacing the old _validate_amax with the new validate_attr method centralizes validation logic and improves maintainability.
Force-pushed from 555ceea to a13c5d4.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

@@ Coverage Diff @@
## main #357 +/- ##
==========================================
+ Coverage 73.77% 73.86% +0.08%
==========================================
Files 171 171
Lines 17619 17629 +10
==========================================
+ Hits 12998 13021 +23
+ Misses 4621 4608 -13

☔ View full report in Codecov by Sentry.
Actionable comments posted: 3
🧹 Nitpick comments (6)
modelopt/torch/quantization/model_quant.py (1)
108-112
: Avoid false positives: don’t validate_bias_value
with generic “>= 0”.
validate_attr
uses a non-negative check that’s inappropriate for bias (bias can be negative). This will spam warnings and may mislead downstream gating.Options:
- Prefer: make
validate_attr
attribute-aware (see suggested change in tensor_quantizer.py to allow any finite value forbias_value
and strictly positive forpre_quant_scale
).- Or, stop validating
_bias_value
here and keep only["_amax", "_pre_quant_scale"]
.modelopt/torch/quantization/nn/modules/quant_linear.py (2)
120-123
: _setup may run beforeweight_quantizer
exists; guard and/or call super._setup().If
_setup
is invoked on a fresh instance,self.weight_quantizer
may not be initialized yet. Guard it to avoid AttributeError.Apply:
- def _setup(self): - """Overrides and bypass the _setup function.""" - self.weight_quantizer.__class__ = SVDQuantTensorQuantizer + def _setup(self): + """Ensure weight_quantizer exists and install SVDQuantTensorQuantizer.""" + if not hasattr(self, "weight_quantizer"): + try: + super()._setup() + except Exception: + pass + if hasattr(self, "weight_quantizer"): + self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
136-145
: Robust access to optional LoRA attrs; avoid AttributeError if quantizer class isn’t swapped.Access
svdquant_lora_a/b
viagetattr(..., default=None)
.Apply:
- if ( - self._not_sequential_quantizers() - and self.weight_quantizer.svdquant_lora_a is not None - and self.weight_quantizer.svdquant_lora_b is not None - ): + lora_a = getattr(self.weight_quantizer, "svdquant_lora_a", None) + lora_b = getattr(self.weight_quantizer, "svdquant_lora_b", None) + if self._not_sequential_quantizers() and lora_a is not None and lora_b is not None: - lora_a = F.linear(input, weight=self.weight_quantizer.svdquant_lora_a) - lora_b = F.linear(lora_a, weight=self.weight_quantizer.svdquant_lora_b) - return lora_b + lora_a_out = F.linear(input, weight=lora_a) + lora_b_out = F.linear(lora_a_out, weight=lora_b) + return lora_b_out return None @@ - has_svdquant_lora = ( - self._not_sequential_quantizers() - and self.weight_quantizer.svdquant_lora_a is not None - and self.weight_quantizer.svdquant_lora_b is not None - ) + lora_a = getattr(self.weight_quantizer, "svdquant_lora_a", None) + lora_b = getattr(self.weight_quantizer, "svdquant_lora_b", None) + has_svdquant_lora = self._not_sequential_quantizers() and lora_a is not None and lora_b is not None @@ - if ( - self._not_sequential_quantizers() - and self.weight_quantizer.svdquant_lora_a is not None - and self.weight_quantizer.svdquant_lora_b is not None - ): - self.weight.data.copy_( - self.weight - + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a - ) + lora_a = getattr(self.weight_quantizer, "svdquant_lora_a", None) + lora_b = getattr(self.weight_quantizer, "svdquant_lora_b", None) + if self._not_sequential_quantizers() and lora_a is not None and lora_b is not None: + self.weight.data.copy_(self.weight + (lora_b @ lora_a))Also applies to: 147-161, 166-179
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
828-828
: OK to validate export amax; consider clone to avoid mutating buffer.
export_amax
editsamax
in place before validation. If side effects are undesirable, clone before edits.Apply:
- if not hasattr(self, "_amax_shape_for_export"): - amax = self.amax + if not hasattr(self, "_amax_shape_for_export"): + amax = self.amax.clone()And keep subsequent ops on this clone.
modelopt/torch/quantization/model_calib.py (2)
574-589
: AWQ-lite gating: rely on attribute-aware validate_attr.This block assumes “>= 0” semantics; with the proposed validate_attr fix, pre_quant_scale will be strictly positive and amax non-negative. Keep as-is after that change.
633-667
: Postprocess robustness: guard when best_scale is None.If no search steps ran (or disabled), avoid applying a None scale.
Apply:
- if module.awq_lite.is_enabled: + if module.awq_lite.is_enabled and module.awq_lite.best_scale is not None: apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py (13 hunks)
- modelopt/torch/quantization/model_quant.py (2 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/model_quant.py (2)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): validate_attr (507-523)
modelopt/torch/quantization/model_calib.py (4)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6): is_enabled (395-397), forward (843-942), disable (399-404), validate_attr (507-523), amax (236-241, 244-255)
- modelopt/torch/utils/network.py (2): bind_forward_method (634-653), unpatch_forward_method (656-660)
- modelopt/torch/quantization/utils.py (1): enable_weight_access_and_writeback (424-443)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): TensorQuantizer (65-1107)
- modelopt/torch/quantization/nn/modules/quant_module.py (3): QuantLinearConvBase (129-169), _setup (118-126), _setup (163-169)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
629-631
: Good: raise ValueError for missing block size. Switching from assert to an explicit ValueError improves error clarity and survivability.
773-777
: Good: explicit shape validation for block-quant path. Raising ValueError on shape drift is appropriate here.
modelopt/torch/quantization/model_calib.py (2)
440-443
: Graceful AWQ-lite skip when forward_loop is None. Good change to disable instead of erroring.
855-856
: Rank-aware logging: LGTM. Swapping to print_rank_0 avoids duplicate logs across ranks.
Actionable comments posted: 1
🧹 Nitpick comments (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
257-263
: reset_amax now resets bias too: clarify API contract

This changes behavior; make it explicit in the docstring (or split into a separate reset if needed).

- def reset_amax(self):
-     """Reset amax to None."""
+ def reset_amax(self):
+     """Reset amax to None and reset bias state/calibrator."""
629-631
: Good: convert assert to ValueError; consider richer context in the message

Approving the change. Optional: include dim/key context to speed debugging.

- raise ValueError("block size for dynamic quantization not found.")
+ raise ValueError(
+     f"Block size for dynamic quantization not found "
+     f"(inputs.dim={inputs.dim()}, keys={list(self.block_sizes.keys())})."
+ )
828-829
: Pass semantic name 'amax' (not '_amax') to validator

If the validator keys by attr semantics, prefer 'amax'. If you keep leading underscores, ensure the validator strips them (as suggested).

- self.validate_attr(attr_name="_amax", attr_value=amax)
+ self.validate_attr(attr_name="amax", attr_value=amax)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
773-777
: Good guard against input shape drift. Clear error with explicit shapes. This will prevent hard-to-trace silent mis-quantization.
507-524
: Make validate_attr attribute‑aware and avoid logging full tensorsThe generic ">= 0 and finite" check is incorrect for several attrs (e.g., bias_value can be negative; pre_quant_scale should be > 0). Also avoid dumping large tensors in messages. This was raised previously.
- def validate_attr( - self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" - ): - """Validate attribute.""" - attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None) - if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta): - return True - is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value)) - if is_valid: - return True - name = f"{name} " if name else "" - msg = f"{name}{attr_name} contains invalid values: {attr_value}" - if warn_error: - warnings.warn(msg) - if raise_error: - raise ValueError(msg) - return False + def validate_attr( + self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" + ): + """Validate attribute value based on attribute semantics. + - amax: finite and >= 0 + - pre_quant_scale: finite and > 0 + - bias_value: finite (sign unconstrained) + - default: finite + """ + value = attr_value if attr_value is not None else getattr(self, attr_name, None) + if value is None or (isinstance(value, torch.Tensor) and value.is_meta): + return True + + key = attr_name.lstrip("_") + is_finite = torch.isfinite(value).all() + if key in ("amax",): + is_valid = bool(is_finite and (value >= 0).all()) + elif key in ("pre_quant_scale",): + is_valid = bool(is_finite and (value > 0).all()) + elif key in ("bias_value",): + is_valid = bool(is_finite) + else: + is_valid = bool(is_finite) + + if is_valid: + return True + + # Compact message to avoid dumping large tensors + name_prefix = f"{name} " if name else "" + shape = tuple(value.shape) if hasattr(value, "shape") else () + any_nan = bool(torch.isnan(value).any()) + any_inf = bool(torch.isinf(value).any()) + msg = ( + f"{name_prefix}{attr_name} invalid (shape={shape}, any_nan={any_nan}, any_inf={any_inf})." + ) + if warn_error: + warnings.warn(msg) + if raise_error: + raise ValueError(msg) + return False
Force-pushed from a548dea to 2f7b2eb.
Actionable comments posted: 2
🧹 Nitpick comments (6)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
65-115
: Unify LoRA buffer setters via _set_buffer; keep device/dtype; reduce code duplication. Setters duplicate buffer-registration logic and miss a consistent to(device/dtype) path.
Apply:
@svdquant_lora_a.setter def svdquant_lora_a(self, value): """Lora a weights for svdquant.""" assert value is not None, "svdquant_lora_a cannot be set to None." - - if not isinstance(value, torch.Tensor): - value = torch.tensor(value) - - if not hasattr(self, "_svdquant_lora_a"): - self.register_buffer("_svdquant_lora_a", value.clone().detach()) - else: - if self._svdquant_lora_a.shape != value.shape: - raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.") - self._svdquant_lora_a.data.copy_( - value.clone().detach().to(self._svdquant_lora_a.device) - ) + if not isinstance(value, torch.Tensor): + value = torch.tensor(value) + if hasattr(self, "_svdquant_lora_a") and self._svdquant_lora_a.shape != value.shape: + raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.") + value = value.clone().detach().to( + getattr(self, "_svdquant_lora_a", value).device if hasattr(self, "_svdquant_lora_a") else value.device + ) + self._set_buffer("_svdquant_lora_a", value) @@ @svdquant_lora_b.setter def svdquant_lora_b(self, value): """Lora b weights for svdquant.""" assert value is not None, "svdquant_lora_b cannot be set to None." - - if not isinstance(value, torch.Tensor): - value = torch.tensor(value) - - if not hasattr(self, "_svdquant_lora_b"): - self.register_buffer("_svdquant_lora_b", value.clone().detach()) - else: - if self._svdquant_lora_b.shape != value.shape: - raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.") - self._svdquant_lora_b.data.copy_( - value.clone().detach().to(self._svdquant_lora_b.device) - ) + if not isinstance(value, torch.Tensor): + value = torch.tensor(value) + if hasattr(self, "_svdquant_lora_b") and self._svdquant_lora_b.shape != value.shape: + raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.") + value = value.clone().detach().to( + getattr(self, "_svdquant_lora_b", value).device if hasattr(self, "_svdquant_lora_b") else value.device + ) + self._set_buffer("_svdquant_lora_b", value)
176-179
: Avoid extra allocation when folding; use in-place add_. Current code materializes weight + residual then copies. Adding in-place is cheaper.
Apply:
- self.weight.data.copy_( - self.weight - + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a - ) + self.weight.data.add_(self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a)modelopt/torch/quantization/model_calib.py (3)
574-585
: Surface validation context and NaN/Inf reasons in logs. Pass name to validate_attr and enable warn_error for better diagnostics; you already warn, but validate_attr can emit attribute-aware context too once improved.
Apply:
- for tq in [self.input_quantizer, self.weight_quantizer]: - for attr in ["_pre_quant_scale", "_amax"]: - if not tq.validate_attr(attr_name=attr): + for tq in [self.input_quantizer, self.weight_quantizer]: + for attr in ["_pre_quant_scale", "_amax"]: + if not tq.validate_attr(attr_name=attr, warn_error=True, name=name): disable_awq = True warnings.warn( f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite" ) break
633-653
: Postprocess: guard against missing best_scale when is_enabled toggles late. If is_enabled flipped off during search, best_scale can remain None. Bail out to max_calibrate in that case.
Apply:
- if module.awq_lite.is_enabled: - apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale) + if module.awq_lite.is_enabled and module.awq_lite.best_scale is not None: + apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale) + elif module.awq_lite.is_enabled: + warnings.warn(f"awq_lite: No best_scale for {name}, quantizing with max calibration.") + max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
656-663
: No-op modules in second forward_loop: demote to rank-0 logs. These user-facing warnings will repeat on each rank. Prefer print_rank_0 for consistency.
Apply:
- warnings.warn( - "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the" - f" {name}. Please provide a valid `forward_loop` function that can be used to" - " forward data through the model many times." - ) + print_rank_0( + "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the" + f" {name}. Please provide a valid `forward_loop` function that can be used to" + " forward data through the model many times." + )modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
828-828
: Consider warn_error=True when validating export amax. Surfacing invalid amax (shape/NaN/Inf) during export helps early detection.
Apply:
- self.validate_attr(attr_name="_amax", attr_value=amax) + self.validate_attr(attr_name="_amax", attr_value=amax, warn_error=True)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py (13 hunks)
- modelopt/torch/quantization/model_quant.py (1 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/model_quant.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): TensorQuantizer (65-1107)
- modelopt/torch/quantization/nn/modules/quant_module.py (3): QuantLinearConvBase (129-169), _setup (118-126), _setup (163-169)
modelopt/torch/quantization/model_calib.py (4)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (8): axis (283-285, 288-290), is_enabled (395-397), forward (843-942), disable (399-404), validate_attr (507-523), amax (236-241, 244-255)
- modelopt/torch/utils/network.py (2): bind_forward_method (634-653), unpatch_forward_method (656-660)
- modelopt/torch/quantization/utils.py (1): enable_weight_access_and_writeback (424-443)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
31-35
: Harden GLOBALS import across PyTorch versions; add safe fallback. Private import paths change across versions; the current hasattr check can still fail. Use try/except and a dummy fallback to avoid import-time crashes.
Apply:
-if hasattr(torch.onnx, "_globals"): - from torch.onnx._globals import GLOBALS -else: # torch >= 2.9 - from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +try: + from torch.onnx._globals import GLOBALS # older PyTorch +except Exception: + try: + from torch.onnx._internal.torchscript_exporter._globals import GLOBALS # PyTorch >= 2.9 + except Exception: + class _DummyGlobals: + in_onnx_export = False + GLOBALS = _DummyGlobals()
507-524
: Make validate_attr attribute-aware; enforce pre_quant_scale > 0; avoid dumping large tensors. The current rule (>= 0 and not inf) is wrong for pre_quant_scale (must be > 0) and too generic. Also avoid logging entire tensors.
Apply:
- def validate_attr( - self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" - ): - """Validate attribute.""" - attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None) - if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta): - return True - is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value)) - if is_valid: - return True - name = f"{name} " if name else "" - msg = f"{name}{attr_name} contains invalid values: {attr_value}" - if warn_error: - warnings.warn(msg) - if raise_error: - raise ValueError(msg) - return False + def validate_attr( + self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" + ): + """Validate attribute value based on attribute semantics.""" + value = attr_value if attr_value is not None else getattr(self, attr_name, None) + if value is None or (isinstance(value, torch.Tensor) and getattr(value, "is_meta", False)): + return True + key = attr_name.lstrip("_") + is_finite = torch.isfinite(value).all() if isinstance(value, torch.Tensor) else math.isfinite(float(value)) + if key == "amax": + is_valid = bool(is_finite and (value >= 0).all()) + elif key == "pre_quant_scale": + is_valid = bool(is_finite and (value > 0).all()) + elif key in ("bias", "bias_value"): + is_valid = bool(is_finite) + else: + is_valid = bool(is_finite) + if is_valid: + return True + name_prefix = f"{name} " if name else "" + shape = tuple(value.shape) if hasattr(value, "shape") else () + any_nan = bool(torch.isnan(value).any()) if isinstance(value, torch.Tensor) else False + any_inf = bool(torch.isinf(value).any()) if isinstance(value, torch.Tensor) else not is_finite + msg = f"{name_prefix}{attr_name} invalid (shape={shape}, any_nan={any_nan}, any_inf={any_inf})." + if warn_error: + warnings.warn(msg) + if raise_error: + raise ValueError(msg) + return False
def _setup(self):
    """Overrides and bypass the _setup function."""
    self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
🛠️ Refactor suggestion | 🟠 Major
Call super()._setup() before reclassing the quantizer to avoid missing attributes
If _setup bypasses the base implementation, weight_quantizer may never be created, leading to AttributeError. Initialize via super() and then reclass.
Apply:
- def _setup(self):
- """Overrides and bypass the _setup function."""
- self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+ def _setup(self):
+ """Override setup and upgrade the weight quantizer to SVDQuantTensorQuantizer."""
+ super()._setup()
+ if not isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
+ # Safe in CPython for pure-Python classes with identical layout
+ self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
 def _setup(self):
-    """Overrides and bypass the _setup function."""
-    self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+    """Override setup and upgrade the weight quantizer to SVDQuantTensorQuantizer."""
+    super()._setup()
+    if not isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
+        # Safe in CPython for pure-Python classes with identical layout
+        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/quant_linear.py around lines 120 to
123, the _setup override currently reassigns weight_quantizer's class before the
base class has run, which can leave weight_quantizer uninitialized and raise
AttributeError; fix by calling super()._setup() first to allow the base
implementation to create and initialize weight_quantizer, then reclass
weight_quantizer to SVDQuantTensorQuantizer, preserving the original instance
and any required attributes.
Force-pushed from 2f7b2eb to a72b7ea.
Actionable comments posted: 0
🧹 Nitpick comments (5)
modelopt/torch/quantization/model_quant.py (1)
108-112
: Also validate bias_value during post-calib checks (after fixing validate_attr semantics). Add "_bias_value" to the attributes validated post-calibration to catch invalid bias tensors. Note: ensure TensorQuantizer.validate_attr treats bias semantics correctly first (bias can be negative). Based on learnings.
- for name, module in model.named_modules(): - if isinstance(module, TensorQuantizer): - for attr_name in ["_amax", "_pre_quant_scale"]: - module.validate_attr(attr_name=attr_name, warn_error=True, name=name) + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer): + for attr_name in ["_amax", "_bias_value", "_pre_quant_scale"]: + module.validate_attr(attr_name=attr_name, warn_error=True, name=name)modelopt/torch/quantization/nn/modules/quant_linear.py (1)
65-115
: Preserve dtype on LoRA buffer updates. When copying into existing buffers, match both device and dtype to avoid silent dtype flips across setters.
- self._svdquant_lora_a.data.copy_( - value.clone().detach().to(self._svdquant_lora_a.device) - ) + self._svdquant_lora_a.data.copy_( + value.clone().detach().to(device=self._svdquant_lora_a.device, dtype=self._svdquant_lora_a.dtype) + ) @@ - self._svdquant_lora_b.data.copy_( - value.clone().detach().to(self._svdquant_lora_b.device) - ) + self._svdquant_lora_b.data.copy_( + value.clone().detach().to(device=self._svdquant_lora_b.device, dtype=self._svdquant_lora_b.dtype) + )modelopt/torch/quantization/model_calib.py (1)
575-586
: Strengthen validation: enforce pre_quant_scale > 0 and catch NaN. The disable-awq gate relies on TensorQuantizer.validate_attr. The current implementation only checks ">= 0 and not inf"; it misses NaN and allows zeros for pre_quant_scale. Tighten validate_attr semantics (see separate comment) so this gate reliably disables AWQ when scales are invalid.
After updating validate_attr, no code changes needed here; re-run a small calib pass to confirm AWQ isn’t disabled spuriously and that zero/NaN scales trigger the intended warning path.
Also applies to: 587-590
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
507-523
: Make validate_attr attribute-aware; add NaN checks and compact logs. The current ">= 0 and not inf" rule is incorrect for pre_quant_scale (should be strictly > 0) and too strict for bias (can be negative). It also misses NaN. Implement per-attribute semantics and avoid logging full tensors.
- def validate_attr( - self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" - ): - """Validate attribute.""" - attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None) - if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta): - return True - is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value)) - if is_valid: - return True - name = f"{name} " if name else "" - msg = f"{name}{attr_name} contains invalid values: {attr_value}" - if warn_error: - warnings.warn(msg) - if raise_error: - raise ValueError(msg) - return False + def validate_attr( + self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" + ): + """Validate attribute value based on attribute semantics. + - amax: finite and >= 0 + - pre_quant_scale: finite and > 0 + - bias_value: finite (sign unconstrained) + - default: finite + """ + value = attr_value if attr_value is not None else getattr(self, attr_name, None) + if value is None or (isinstance(value, torch.Tensor) and value.is_meta): + return True + is_finite = torch.isfinite(value).all() + key = attr_name.lstrip("_") + if key in ("amax",): + ok = bool(is_finite and (value >= 0).all()) + elif key in ("pre_quant_scale",): + ok = bool(is_finite and (value > 0).all()) + elif key in ("bias_value",): + ok = bool(is_finite) + else: + ok = bool(is_finite) + if ok: + return True + # Compact diagnostics + name_prefix = f"{name} " if name else "" + shape = tuple(value.shape) if hasattr(value, "shape") else () + any_nan = bool(torch.isnan(value).any()) + any_inf = bool(torch.isinf(value).any()) + msg = f"{name_prefix}{attr_name} invalid (shape={shape}, any_nan={any_nan}, any_inf={any_inf})." + if warn_error: + warnings.warn(msg) + if raise_error: + raise ValueError(msg) + return False
828-828
: Consider warning on invalid exported amax. Optionally pass warn_error=True to surface issues during export formatting.
- self.validate_attr(attr_name="_amax", attr_value=amax) + self.validate_attr(attr_name="_amax", attr_value=amax, warn_error=True, name=self.__class__.__name__)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/model_calib.py (13 hunks)
- modelopt/torch/quantization/model_quant.py (1 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T16:40:12.799Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.799Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.
Applied to files:
modelopt/torch/quantization/model_quant.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_quant.py (2)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): validate_attr (507-523)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): TensorQuantizer (65-1107)
- modelopt/torch/quantization/nn/modules/quant_module.py (3): QuantLinearConvBase (129-169), _setup (118-126), _setup (163-169)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (6)
modelopt/torch/quantization/model_calib.py (2)
440-443
: Good guard: skip AWQ Lite when forward_loop is missing. Prevents accidental errors and aligns with the "disable instead of error" objective.
609-609
: Rank-aware messaging looks good. Using print_rank_0 reduces log spam under DDP.
Also applies to: 631-631, 917-917
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
31-35
: ONNX GLOBALS import fallback (prior suggestion). A try/except import with a safe dummy fallback is more robust across PyTorch versions.
630-631
: Prefer ValueError over assert for missing block size. Good change; avoids stripping by -O and provides clearer error semantics.
773-777
: Clear error on dynamic shape change in block-quant path. Good guard; prevents silent mis-quantization.
modelopt/torch/quantization/nn/modules/quant_linear.py (1)
120-125
: Bug: _setup bypasses base init; weight_quantizer may be undefined. Accessing self.weight_quantizer before super()._setup() risks AttributeError. Initialize via the base class and then reclass.
- def _setup(self): - """Overrides and bypass the _setup function.""" - if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer): - return - self.weight_quantizer.__class__ = SVDQuantTensorQuantizer + def _setup(self): + """Override setup and upgrade the weight quantizer to SVDQuantTensorQuantizer.""" + super()._setup() + if hasattr(self, "weight_quantizer") and not isinstance( + self.weight_quantizer, SVDQuantTensorQuantizer + ): + # Safe for pure-Python classes with identical layout + self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/quantization/model_calib.py (13 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (4)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): forward (843-942), validate_attr (507-523), amax (236-241, 244-255)
- modelopt/torch/utils/network.py (2): bind_forward_method (634-653), unpatch_forward_method (656-660)
- modelopt/torch/quantization/utils.py (1): enable_weight_access_and_writeback (424-443)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
507-523
: Makevalidate_attr
attribute-aware (bias can be negative, pre-quant scale must be > 0).This helper now treats every attribute as “finite and ≥ 0”. That’s fine for
amax
, but it incorrectly rejects perfectly valid negativebias_value
(bias calibration routinely produces signed values) and it letspre_quant_scale
slip through at zero. As soon aswarn_error
orraise_error
is enabled for those attrs we’ll surface false alarms/break flows. Please reinstate per-attribute rules.def validate_attr( self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name="" ): """Validate attribute.""" - attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None) - if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta): - return True - is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value)) - if is_valid: + value = attr_value if attr_value is not None else getattr(self, attr_name, None) + if value is None or (isinstance(value, torch.Tensor) and value.is_meta): + return True + + tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value) + key = attr_name.lstrip("_") + + if tensor.is_floating_point() or tensor.is_complex(): + finite_mask = torch.isfinite(tensor) + else: + finite_mask = torch.ones_like(tensor, dtype=torch.bool) + + if key in {"amax"}: + valid_mask = finite_mask & (tensor >= 0) + elif key in {"pre_quant_scale"}: + valid_mask = finite_mask & (tensor > 0) + elif key in {"bias", "bias_value"}: + valid_mask = finite_mask + else: + valid_mask = finite_mask + + if valid_mask.all(): return True name = f"{name}." if name else "" - msg = f"{name}{attr_name} contains invalid values: {attr_value}" + any_nan = bool(tensor.is_floating_point() and torch.isnan(tensor).any()) + any_inf = bool(tensor.is_floating_point() and torch.isinf(tensor).any()) + shape = tuple(tensor.shape) + msg = ( + f"{name}{attr_name} contains invalid values " + f"(shape={shape}, any_nan={any_nan}, any_inf={any_inf})" + ) if warn_error: warnings.warn(msg) if raise_error: raise ValueError(msg) return False
…u sync, new SVDQuantTensorQuantizer, minor code clean up Signed-off-by: realAsma <[email protected]>
Force-pushed from 08c2628 to b8e5ad2.
LGTM
… new SVDQuantTensorQuantizer, minor code clean up (#357) Signed-off-by: realAsma <[email protected]> Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: Clean up PR
Overview:
Moved svdquant-specific attributes to the new SVDQuantTensorQuantizer.
Usage
No user-facing changes.
Testing
Ran with unit tests and hf_ptq.py int4_awq.

Summary by CodeRabbit
New Features
Improvements
Bug Fixes