
Conversation

realAsma (Contributor) commented Sep 23, 2025

What does this PR do?

Type of change: Clean-up PR

Overview:

  1. Removed amax validation from forward to reduce CPU-GPU sync; this improved auto_quantize score estimation time by about 10% (see the sketch after this list).
  2. Moved SVDQuant-specific attributes to SVDQuantTensorQuantizer.
  3. Removed some unused code.
  4. awq_lite: disable AWQ instead of raising errors.
  5. General code cleanup and improvements.
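A minimal sketch of item 1 (illustrative only; the simplified forward and the helper below are not the actual ModelOpt code):

import warnings
import torch

# Before: a check like the following inside TensorQuantizer.forward forces a
# GPU -> CPU sync on every call, because the boolean must be materialized on the host:
#     assert torch.all(self._amax >= 0) and not torch.any(torch.isinf(self._amax))

# After: validate once per quantizer after calibration, outside the hot path.
def validate_calibrated_amax(model):
    for name, module in model.named_modules():
        amax = getattr(module, "_amax", None)
        if amax is not None and not bool(torch.all(amax >= 0) and torch.all(torch.isfinite(amax))):
            warnings.warn(f"{name}: amax contains invalid values")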

Usage

No user-facing changes.

Testing

Ran unit tests and hf_ptq.py with int4_awq.

Summary by CodeRabbit

  • New Features

    • Added SVD-based LoRA residual support for quantized linear layers and a new SVD-aware quantizer.
  • Improvements

    • Calibration and quantization status messages now appear only from the primary process for clearer logs.
    • Centralized attribute validation, stronger post-calibration checks, and safer reset/state handling.
    • Per-module lifecycle hooks to manage calibration/patching state and context propagation.
  • Bug Fixes

    • Quantization flows (AWQ/AWQ‑Lite/Clip/SVD) now handle missing/invalid paths gracefully with warnings and early exits.

realAsma requested a review from a team as a code owner September 23, 2025 17:45
realAsma requested a review from jingyu-ml September 23, 2025 17:45

coderabbitai bot commented Sep 23, 2025

Warning

Rate limit exceeded

@realAsma has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 16 minutes and 52 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between 08c2628 and b8e5ad2.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (12 hunks)
  • modelopt/torch/quantization/model_quant.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)

Walkthrough

Adds rank-aware logging and warnings, per-module AWQLite/AWQClip gating with setup/cleanup hooks and module-name propagation into postprocess; validates TensorQuantizer attributes after calibration using the new validator; removes public svdquant LoRA properties from TensorQuantizer and introduces SVD LoRA support via SVDQuantTensorQuantizer and SVDQuantLinear.

Changes

Cohort / File(s) | Summary

  • Rank-aware logging & AWQ gating (modelopt/torch/quantization/model_calib.py): Replace plain print with print_rank_0 and warnings.warn; add per-module AWQLiteHelper with setup()/cleanup() and is_enabled gating that patches/unpatches forwards and manages input_quantizer axis state; propagate the module name into postprocess(module, name); add guards/warnings for missing forward_loop and early-exit paths; convert status messages to rank-aware outputs.
  • Calibration validation (modelopt/torch/quantization/model_quant.py): After the forward pass during calibrate, iterate model.named_modules() and validate each TensorQuantizer's attributes (e.g., _amax, _pre_quant_scale, bias) via validate_attr(..., warn_error=True, name=mod_name); do this after forward_loop handling and before CUDA/cache messaging.
  • SVD LoRA in linear modules (modelopt/torch/quantization/nn/modules/quant_linear.py): Add SVDQuantTensorQuantizer exposing LoRA buffer management; extend SVDQuantLinear with _setup() to swap the quantizer class, _apply_pre_quant_scale, _compute_lora_residual, an augmented forward that applies the pre-quant scale and optionally adds the LoRA residual, and fold_weight updates/cleanup to handle LoRA residuals and remove LoRA attributes after folding.
  • TensorQuantizer refactor & validation (modelopt/torch/quantization/nn/modules/tensor_quantizer.py): Move the GLOBALS import to module top level; remove public svdquant_lora_a/b properties; add a central validate_attr(attr_value, attr_name, raise_error, warn_error, name) validator; replace several assert checks with ValueError; call reset_bias() in reset_amax; update export/validation and extra_repr to use validate_attr.
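For reference, the calibration-validation change in the table above amounts to a post-forward loop along these lines (condensed from the diffs quoted later in this review; the TensorQuantizer import path is assumed):

from modelopt.torch.quantization.nn import TensorQuantizer  # import path assumed

def validate_after_calibration(model):
    # Warn (rather than raise) about quantizers whose calibrated attributes
    # ended up invalid, e.g. NaN/Inf amax or pre_quant_scale.
    for name, module in model.named_modules():
        if isinstance(module, TensorQuantizer):
            for attr_name in ["_amax", "_pre_quant_scale"]:
                module.validate_attr(attr_name=attr_name, warn_error=True, name=name)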

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Cal as Calibrator
  participant AW as AWQFlow (awq_lite / awq_clip)
  participant M as Module(s)
  participant L as Logger (rank-0)

  User->>Cal: calibrate(model, dataloader)
  Cal->>AW: check is_enabled
  alt disabled or invalid config
    AW-->>L: warn and skip AWQ flow
  else
    AW->>AW: setup()  -- patch module forwards, set per-module flags
    Cal->>M: run forward_loop()
    AW->>M: postprocess(module, name) for each module
    AW->>AW: cleanup() -- restore forwards/attributes
    AW-->>L: info AWQ flow complete
  end
sequenceDiagram
  autonumber
  participant In as Input
  participant S as SVDQuantLinear
  participant Q as SVDQuantTensorQuantizer
  participant L as LinearOp
  participant Out as Output

  In->>S: forward(x)
  S->>Q: check _pre_quant_scale and svdquant_lora_a/b
  alt LoRA present
    S->>S: x_scaled = _apply_pre_quant_scale(x)
    S->>S: R = _compute_lora_residual(x_scaled)
    S->>L: main linear forward (pre-quant temporarily disabled)
    L-->>S: y_main
    S->>Out: y = y_main + R
  else
    S->>L: main linear forward
    L-->>Out: y
  end
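In code, the LoRA branch of the diagram above corresponds roughly to the sketch below (attribute and helper names are taken from the PR summary; _quantized_linear is a stand-in for the regular quantized forward, not a real method):

import torch.nn.functional as F

def svdquant_forward_sketch(module, x):
    """Rough shape of SVDQuantLinear.forward; the real implementation differs in detail."""
    wq = module.weight_quantizer
    lora_a = getattr(wq, "svdquant_lora_a", None)
    lora_b = getattr(wq, "svdquant_lora_b", None)
    if lora_a is not None and lora_b is not None:
        x_scaled = module._apply_pre_quant_scale(x)               # apply the pre-quant scale once, up front
        residual = F.linear(F.linear(x_scaled, lora_a), lora_b)   # low-rank LoRA residual
        y_main = module._quantized_linear(x_scaled)               # main path runs with pre-quant scaling disabled
        return y_main + residual
    return module._quantized_linear(x)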
sequenceDiagram
  autonumber
  participant Cal as calibrate()
  participant Mods as Modules
  participant TQ as TensorQuantizer
  participant L as Logger (rank-0)

  Cal->>Mods: run forward_loop()
  Cal->>Mods: iterate modules
  loop each module with TensorQuantizer
    Cal->>TQ: validate_attr(_amax, "amax", warn_error=True, name=mod_name)
    Cal->>TQ: validate_attr(_pre_quant_scale, "pre_quant_scale", warn_error=True, name=mod_name)
    Cal->>TQ: validate_attr(_bias_value, "bias", warn_error=True, name=mod_name)
  end
  Cal-->>L: info validation complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I thump on rank‑zero logs with glee,
Patch the forwards, set LoRA free.
I check each scale and tidy space,
Warn, cleanup, and restore the place.
Hop—calibration finished, carrot‑grace. 🥕🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 61.54%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly captures the primary performance improvement of removing the amax validation in TensorQuantizer.forward to reduce CPU-GPU synchronization, and it also highlights the new SVDQuantTensorQuantizer class and the minor cleanup in a single concise sentence. It directly reflects the main objectives of the pull request without extraneous detail, helping reviewers quickly understand the scope of the changes when browsing PR history.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

realAsma requested a review from jingyu-ml and removed the request for jingyu-ml September 23, 2025 17:51
coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)

518-531: Bug: undefined name ‘module’ inside update_best_params

Should reference self, not module. This crashes postprocess.

-def update_best_params(self):
-    if not module.awq_lite.is_enabled:
-        return
-    self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
-    self.awq_lite.best_scale = get_scale(
-        self.awq_lite.act_scale,
-        self.awq_lite.weight_scale,
-        self.awq_lite.best_alpha,
-        (
-            self.parallel_state.tensor_parallel_group
-            if is_quantized_column_parallel_linear(self)
-            else None
-        ),
-    )
+def update_best_params(self):
+    if not self.awq_lite.is_enabled:
+        return
+    self.awq_lite.best_alpha = min(self.awq_lite.loss, key=self.awq_lite.loss.get)
+    self.awq_lite.best_scale = get_scale(
+        self.awq_lite.act_scale,
+        self.awq_lite.weight_scale,
+        self.awq_lite.best_alpha,
+        (
+            self.parallel_state.tensor_parallel_group
+            if is_quantized_column_parallel_linear(self)
+            else None
+        ),
+    )
🧹 Nitpick comments (3)
modelopt/torch/quantization/model_quant.py (1)

144-145: Docstring typo: s/quaTruent_cfg/quant_cfg

Minor copy fix.

-    performs calibration as specified by ``quaTruent_cfg``.
+    performs calibration as specified by ``quant_cfg``.
modelopt/torch/quantization/nn/modules/quant_linear.py (1)

120-123: Guard _setup against missing weight_quantizer

If _setup runs before weight_quantizer exists, this will raise. Guard and fallback to super.

 def _setup(self):
     """Overrides and bypass the _setup function."""
-    self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+    if not hasattr(self, "weight_quantizer"):
+        super()._setup()
+    self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

507-524: Make validate_attr robust to non-tensor numerics and NaN

Coerce numerics to tensors; prefer isfinite to catch NaNs. Non-numeric types (e.g., dict) should early-return.

 def validate_attr(
     self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
 ):
     """Validate attribute."""
-    attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
+    attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
     if attr_value is None:
         return True
-    is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value))
+    # Coerce numeric to tensor; skip validation for non-numeric types (e.g., config dicts)
+    if not isinstance(attr_value, torch.Tensor):
+        try:
+            attr_value = torch.as_tensor(attr_value)
+        except Exception:
+            return True
+    is_valid = torch.all(attr_value >= 0) and torch.all(torch.isfinite(attr_value))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7d5f636 and 4623fb0.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (13 hunks)
  • modelopt/torch/quantization/model_quant.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/model_quant.py (2)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • validate_attr (507-523)
modelopt/torch/quantization/model_calib.py (5)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (8)
  • axis (283-285)
  • axis (288-290)
  • is_enabled (395-397)
  • forward (843-943)
  • disable (399-404)
  • validate_attr (507-523)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/network.py (2)
  • bind_forward_method (634-653)
  • unpatch_forward_method (656-660)
modelopt/torch/trace/symbols.py (2)
  • disable (108-135)
  • named_modules (444-447)
modelopt/torch/quantization/utils.py (1)
  • enable_weight_access_and_writeback (424-443)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (65-1108)
modelopt/torch/quantization/nn/modules/quant_module.py (3)
  • QuantLinearConvBase (129-169)
  • _setup (118-126)
  • _setup (163-169)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (23)
modelopt/torch/quantization/model_calib.py (15)

28-28: LGTM: rank-aware logging import

print_rank_0 improves multi-rank UX.
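For context, a rank-0-only print helper is typically just a guarded print like the sketch below (the real implementation lives in modelopt/torch/utils/logging.py and may differ):

import torch.distributed as dist

def print_rank_0(*args, **kwargs):
    # Print only on rank 0 (or when not running distributed) so the same
    # message is not repeated once per process.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)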


372-379: LGTM: rank-aware warnings in smoothquant

Non-blocking modules are cleanly skipped with informative, rank-0 messages.

Also applies to: 389-389


440-443: LGTM: awq_lite gracefully disabled without forward_loop

Better UX than hard error.


459-472: LGTM: AWQ-lite setup and gating

Forward patching and per-channel axis gating are correct. Input quantizer temporarily disabled is fine.


473-478: LGTM: Cleanup restores forward and removes calib flag

Prevents lingering patched state.


542-557: LGTM: Cache-mode statistics collection

Correct handling for DTensor, act scale accumulation, and optional input quantizer stats.


572-587: LGTM: Attribute validation gate for awq_lite

Good use of validate_attr on pre-quant scale and amax to fail fast.


584-587: LGTM: Disable AWQ-lite on invalid attributes

Safe fallback.


604-605: LGTM: Rank-aware cache notice


626-627: LGTM: Rank-aware search notice


630-650: Postprocess depends on fix in update_best_params

Flow is correct; ensure the self/module bug is fixed to populate best_scale.


651-664: LGTM: Guard for no-cache or no-search passes with actionable warnings

Clear guidance to users to adjust forward_loop.


665-665: LGTM: Cleanup after AWQ-lite

Avoids leaving patched methods or state.


852-852: LGTM: Rank-aware log in awq_clip


913-913: LGTM: Rank-aware SVD log

modelopt/torch/quantization/nn/modules/quant_linear.py (3)

65-116: LGTM: SVDQuantTensorQuantizer properties align with TensorQuantizer patterns

Setters follow buffer semantics and shape stability. Consider exporting if intended as public.

If this class is meant to be public, add it to __all__.


147-161: LGTM: Forward integrates LoRA residual with pre-quant scaling

Correctly disables pre_quant_scale for main path to avoid double scaling.


163-187: LGTM: Fold weight applies LoRA residual and cleans up buffers

Avoids double-applying residual post-fold.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5)

31-35: LGTM: Centralized ONNX GLOBALS import

Avoids runtime import in forward; supports multiple torch versions.
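For reference, the version-guarded import in question (as quoted in a later diff in this review):

import torch

if hasattr(torch.onnx, "_globals"):
    from torch.onnx._globals import GLOBALS
else:  # torch >= 2.9
    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS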


257-263: LGTM: reset_amax also resets bias

Keeps calibration state consistent.


629-631: LGTM: Explicit ValueError for missing block size

Clearer error than assert.


773-777: LGTM: Raise on input shape change during block-quant

Prevents silent misquantization.


828-829: LGTM: Validate exported amax

Early catch for invalid scales.

realAsma force-pushed the asma/remove_validate_amax_from_tq_forward branch from f8bbfe5 to 555ceea on September 23, 2025 23:05
coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
modelopt/torch/quantization/model_calib.py (2)

372-372: Minor: Consider consistent formatting for warning messages.

For consistency, consider using print_rank_0 for warnings as well since you've migrated other print statements.

-                warnings.warn(f"{name} is not calibrated, skip smoothing")
+                print_rank_0(f"Warning: {name} is not calibrated, skip smoothing")

375-375: Minor: Consider consistent warning format.

Similar to Line 372, consider using print_rank_0 for consistency with the distributed printing pattern.

Also applies to: 378-378

modelopt/torch/quantization/nn/modules/quant_linear.py (1)

120-122: Avoid mutating weight_quantizer.__class__; instantiate/convert instead.

Setting self.weight_quantizer.__class__ = SVDQuantTensorQuantizer bypasses __init__/serialization/type-checking and can break debugging/typing; replace it with a conversion constructor (e.g. SVDQuantTensorQuantizer.from_tensor_quantizer(self.weight_quantizer)) or perform the transformation in the parent _setup. Occurrences: modelopt/torch/quantization/nn/modules/quant_linear.py:122 (this PR); similar patterns found in modelopt/torch/trace/modules/concat.py, modelopt/torch/nas/traced_hp.py, modelopt/torch/opt/dynamic.py; consolidate or document if intentional.
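A minimal sketch of such a conversion constructor; from_tensor_quantizer is the reviewer's hypothetical name, not an existing API, and the state copy shares the underlying tensors:

from modelopt.torch.quantization.nn import TensorQuantizer  # import path assumed

class SVDQuantTensorQuantizer(TensorQuantizer):
    @classmethod
    def from_tensor_quantizer(cls, tq: TensorQuantizer) -> "SVDQuantTensorQuantizer":
        # Return a new, properly typed instance instead of mutating tq.__class__ in place.
        new_tq = cls.__new__(cls)            # allocate without re-running __init__
        new_tq.__dict__ = dict(tq.__dict__)  # carry over parameters, buffers, and flags (shared tensors)
        return new_tq

# hypothetical use inside SVDQuantLinear._setup():
#     self.weight_quantizer = SVDQuantTensorQuantizer.from_tensor_quantizer(self.weight_quantizer)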

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32f8ec9 and 555ceea.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (13 hunks)
  • modelopt/torch/quantization/model_quant.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/model_quant.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (65-1107)
modelopt/torch/quantization/nn/modules/quant_module.py (3)
  • QuantLinearConvBase (129-169)
  • _setup (118-126)
  • _setup (163-169)
modelopt/torch/quantization/model_calib.py (4)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • is_enabled (395-397)
  • forward (843-942)
  • disable (399-404)
  • validate_attr (507-523)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/network.py (2)
  • bind_forward_method (634-653)
  • unpatch_forward_method (656-660)
modelopt/torch/quantization/utils.py (1)
  • enable_weight_access_and_writeback (424-443)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (24)
modelopt/torch/quantization/model_calib.py (15)

28-28: LGTM! Improved distributed printing.

The switch to rank-aware printing using print_rank_0 is a good improvement to avoid duplicate output in distributed training environments.


31-34: LGTM! Proper ONNX GLOBALS compatibility handling.

The conditional import properly handles different PyTorch versions where the ONNX GLOBALS location has changed, ensuring compatibility across PyTorch versions. The fallback import path is appropriate.


389-389: LGTM! Improved distributed output.

Converting to rank-aware printing is appropriate for distributed environments.


440-442: Improved error handling for missing forward_loop.

Good defensive programming - gracefully handling missing forward_loop with a warning rather than letting it fail later. This is better than the previous behavior.


459-471: Enhanced AWQLite lifecycle management.

The addition of setup(), cleanup(), and is_enabled state management improves the robustness of AWQLite. The axis validation on Line 468-470 is a good safeguard against unsupported configurations.


473-477: Proper cleanup implementation.

Good cleanup logic to remove temporary attributes and unpatch forward methods, preventing potential memory leaks or interference.


518-519: Enhanced validation and error handling.

The addition of is_enabled checks and early returns for disabled modules prevents unnecessary processing. This is a good defensive programming pattern.

Also applies to: 542-542, 584-586


571-587: Critical improvement: Validation of quantization parameters.

This is a significant improvement that validates _pre_quant_scale and _amax attributes before proceeding with AWQ processing. The validation prevents processing with invalid quantization parameters that could lead to incorrect results.


604-604: LGTM! Enhanced progress reporting.

The switch to rank-aware logging for progress updates is appropriate for distributed training.

Also applies to: 626-626


630-630: Updated function signature for enhanced logging.

The addition of name parameter to postprocess enables better context-aware logging and error reporting.


636-636: Improved conditional processing and warning messages.

The conditional check for is_enabled before applying pre-quantization scaling (Line 645) and the informative warning messages provide better user feedback about disabled modules.

Also applies to: 645-650


653-662: Enhanced validation for AWQLite execution.

The additional validation checks for num_cache_steps and num_search_steps help identify configuration issues early, providing clear feedback about invalid forward_loop functions.


665-665: Proper cleanup call.

The explicit cleanup() call ensures proper resource management and prevents potential issues.


852-852: LGTM! Consistent rank-aware logging.

Converting AWQClip logging to use rank-aware printing for consistency with other distributed operations.


913-913: LGTM! Enhanced SVDQuant logging.

Adding rank-aware printing with the module name provides better context for SVDQuant operations in distributed environments.

modelopt/torch/quantization/nn/modules/quant_linear.py (4)

65-115: Well-implemented SVDQuantTensorQuantizer class.

The new SVDQuantTensorQuantizer class properly extends TensorQuantizer with LoRA support. The property implementation follows the same pattern as the parent class with proper buffer registration, shape validation, and device handling.


129-146: LGTM! Clean helper method implementations.

Both _apply_pre_quant_scale and _compute_lora_residual are well-implemented with proper conditional checks and clear logic. The use of _not_sequential_quantizers() provides consistent validation.


147-161: Enhanced forward method with LoRA support.

The forward method cleanly handles the SVDQuant LoRA case by:

  1. Detecting LoRA presence
  2. Applying pre-quantization scaling
  3. Computing LoRA residuals
  4. Disabling pre-quant scaling for the main forward pass
  5. Adding residuals to the final output

The logic is sound and maintains proper separation of concerns.


163-186: Proper weight folding with cleanup.

The fold_weight implementation correctly:

  1. Calls the parent fold_weight method
  2. Adds the LoRA residual to the weight when folding
  3. Cleans up the LoRA attributes after folding

This ensures the folded weight contains the complete transformation.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5)

31-34: LGTM! Improved ONNX compatibility.

The conditional import properly handles different PyTorch versions and ensures GLOBALS is available at the module level, which should resolve TorchDynamo compilation issues mentioned in the PR objectives.


507-523: Excellent addition: Centralized attribute validation.

The new validate_attr method provides a clean, centralized way to validate tensor attributes with configurable error handling (warnings vs exceptions). This is much better than scattered validation logic.


262-262: Enhanced reset functionality.

Adding self.reset_bias() to reset_amax() ensures complete quantizer state reset, which is more thorough than just resetting amax.


629-630: Improved error handling with ValueError.

Converting runtime assertions to proper ValueError exceptions improves error clarity and follows Python best practices for input validation.

Also applies to: 773-777


828-828: LGTM! Using centralized validation.

Replacing the old _validate_amax with the new validate_attr method centralizes validation logic and improves maintainability.

realAsma force-pushed the asma/remove_validate_amax_from_tq_forward branch from 555ceea to a13c5d4 on September 25, 2025 17:24

codecov bot commented Sep 25, 2025

Codecov Report

❌ Patch coverage is 72.22222% with 30 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.86%. Comparing base (d649fb8) to head (b8e5ad2).
⚠️ Report is 1 commits behind head on main.

Files with missing lines (Patch % | Missing lines)
  • modelopt/torch/quantization/model_calib.py: 72.54% | 14 Missing ⚠️
  • ...lopt/torch/quantization/nn/modules/quant_linear.py: 75.75% | 8 Missing ⚠️
  • .../torch/quantization/nn/modules/tensor_quantizer.py: 60.00% | 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #357      +/-   ##
==========================================
+ Coverage   73.77%   73.86%   +0.08%     
==========================================
  Files         171      171              
  Lines       17619    17629      +10     
==========================================
+ Hits        12998    13021      +23     
+ Misses       4621     4608      -13     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
modelopt/torch/quantization/model_quant.py (1)

108-112: Avoid false positives: don’t validate _bias_value with generic “>= 0”.

validate_attr uses a non-negative check that’s inappropriate for bias (bias can be negative). This will spam warnings and may mislead downstream gating.

Options:

  • Prefer: make validate_attr attribute-aware (see suggested change in tensor_quantizer.py to allow any finite value for bias_value and strictly positive for pre_quant_scale).
  • Or, stop validating _bias_value here and keep only ["_amax", "_pre_quant_scale"].
modelopt/torch/quantization/nn/modules/quant_linear.py (2)

120-123: _setup may run before weight_quantizer exists; guard and/or call super._setup().

If _setup is invoked on a fresh instance, self.weight_quantizer may not be initialized yet. Guard it to avoid AttributeError.

Apply:

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+    def _setup(self):
+        """Ensure weight_quantizer exists and install SVDQuantTensorQuantizer."""
+        if not hasattr(self, "weight_quantizer"):
+            try:
+                super()._setup()
+            except Exception:
+                pass
+        if hasattr(self, "weight_quantizer"):
+            self.weight_quantizer.__class__ = SVDQuantTensorQuantizer

136-145: Robust access to optional LoRA attrs; avoid AttributeError if quantizer class isn’t swapped.

Access svdquant_lora_a/b via getattr(..., default=None).

Apply:

-        if (
-            self._not_sequential_quantizers()
-            and self.weight_quantizer.svdquant_lora_a is not None
-            and self.weight_quantizer.svdquant_lora_b is not None
-        ):
+        lora_a = getattr(self.weight_quantizer, "svdquant_lora_a", None)
+        lora_b = getattr(self.weight_quantizer, "svdquant_lora_b", None)
+        if self._not_sequential_quantizers() and lora_a is not None and lora_b is not None:
-            lora_a = F.linear(input, weight=self.weight_quantizer.svdquant_lora_a)
-            lora_b = F.linear(lora_a, weight=self.weight_quantizer.svdquant_lora_b)
-            return lora_b
+            lora_a_out = F.linear(input, weight=lora_a)
+            lora_b_out = F.linear(lora_a_out, weight=lora_b)
+            return lora_b_out
         return None
@@
-        has_svdquant_lora = (
-            self._not_sequential_quantizers()
-            and self.weight_quantizer.svdquant_lora_a is not None
-            and self.weight_quantizer.svdquant_lora_b is not None
-        )
+        lora_a = getattr(self.weight_quantizer, "svdquant_lora_a", None)
+        lora_b = getattr(self.weight_quantizer, "svdquant_lora_b", None)
+        has_svdquant_lora = self._not_sequential_quantizers() and lora_a is not None and lora_b is not None
@@
-            if (
-                self._not_sequential_quantizers()
-                and self.weight_quantizer.svdquant_lora_a is not None
-                and self.weight_quantizer.svdquant_lora_b is not None
-            ):
-                self.weight.data.copy_(
-                    self.weight
-                    + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
-                )
+            lora_a = getattr(self.weight_quantizer, "svdquant_lora_a", None)
+            lora_b = getattr(self.weight_quantizer, "svdquant_lora_b", None)
+            if self._not_sequential_quantizers() and lora_a is not None and lora_b is not None:
+                self.weight.data.copy_(self.weight + (lora_b @ lora_a))

Also applies to: 147-161, 166-179

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

828-828: OK to validate export amax; consider clone to avoid mutating buffer.

export_amax edits amax in place before validation. If side effects are undesirable, clone before edits.

Apply:

-        if not hasattr(self, "_amax_shape_for_export"):
-            amax = self.amax
+        if not hasattr(self, "_amax_shape_for_export"):
+            amax = self.amax.clone()

And keep subsequent ops on this clone.

modelopt/torch/quantization/model_calib.py (2)

574-589: AWQ-lite gating: rely on attribute-aware validate_attr.

This block assumes “>= 0” semantics; with the proposed validate_attr fix, pre_quant_scale will be strictly positive and amax non-negative. Keep as-is after that change.


633-667: Postprocess robustness: guard when best_scale is None.

If no search steps ran (or disabled), avoid applying a None scale.

Apply:

-        if module.awq_lite.is_enabled:
+        if module.awq_lite.is_enabled and module.awq_lite.best_scale is not None:
             apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 555ceea and 6d58352.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (13 hunks)
  • modelopt/torch/quantization/model_quant.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/model_quant.py (2)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • validate_attr (507-523)
modelopt/torch/quantization/model_calib.py (4)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • is_enabled (395-397)
  • forward (843-942)
  • disable (399-404)
  • validate_attr (507-523)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/network.py (2)
  • bind_forward_method (634-653)
  • unpatch_forward_method (656-660)
modelopt/torch/quantization/utils.py (1)
  • enable_weight_access_and_writeback (424-443)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (65-1107)
modelopt/torch/quantization/nn/modules/quant_module.py (3)
  • QuantLinearConvBase (129-169)
  • _setup (118-126)
  • _setup (163-169)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

629-631: Good: raise ValueError for missing block size.

Switching from assert to explicit ValueError improves error clarity and survivability.


773-777: Good: explicit shape validation for block-quant path.

Raising ValueError on shape drift is appropriate here.

modelopt/torch/quantization/model_calib.py (2)

440-443: Graceful AWQ-lite skip when forward_loop is None.

Good change to disable instead of erroring.


855-856: Rank-aware logging: LGTM.

Swapping to print_rank_0 avoids duplicate logs across ranks.

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)

257-263: reset_amax now resets bias too — clarify API contract

This changes behavior; make it explicit in the docstring (or split into a separate reset if needed).

-    def reset_amax(self):
-        """Reset amax to None."""
+    def reset_amax(self):
+        """Reset amax to None and reset bias state/calibrator."""

629-631: Good: convert assert to ValueError; consider richer context in the message

Approving the change. Optional: include dim/key context to speed debugging.

-                raise ValueError("block size for dynamic quantization not found.")
+                raise ValueError(
+                    f"Block size for dynamic quantization not found "
+                    f"(inputs.dim={inputs.dim()}, keys={list(self.block_sizes.keys())})."
+                )

828-829: Pass semantic name 'amax' (not '_amax') to validator

If validator keys by attr semantics, prefer 'amax'. If you keep leading underscores, ensure the validator strips them (as suggested).

-        self.validate_attr(attr_name="_amax", attr_value=amax)
+        self.validate_attr(attr_name="amax", attr_value=amax)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6d58352 and a548dea.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

773-777: Good guard against input shape drift

Clear error with explicit shapes. This will prevent hard-to-trace silent mis-quantization.


507-524: Make validate_attr attribute‑aware and avoid logging full tensors

The generic ">= 0 and finite" check is incorrect for several attrs (e.g., bias_value can be negative; pre_quant_scale should be > 0). Also avoid dumping large tensors in messages. This was raised previously.

-    def validate_attr(
-        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
-    ):
-        """Validate attribute."""
-        attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
-        if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta):
-            return True
-        is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value))
-        if is_valid:
-            return True
-        name = f"{name} " if name else ""
-        msg = f"{name}{attr_name} contains invalid values: {attr_value}"
-        if warn_error:
-            warnings.warn(msg)
-        if raise_error:
-            raise ValueError(msg)
-        return False
+    def validate_attr(
+        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
+    ):
+        """Validate attribute value based on attribute semantics.
+        - amax: finite and >= 0
+        - pre_quant_scale: finite and > 0
+        - bias_value: finite (sign unconstrained)
+        - default: finite
+        """
+        value = attr_value if attr_value is not None else getattr(self, attr_name, None)
+        if value is None or (isinstance(value, torch.Tensor) and value.is_meta):
+            return True
+
+        key = attr_name.lstrip("_")
+        is_finite = torch.isfinite(value).all()
+        if key in ("amax",):
+            is_valid = bool(is_finite and (value >= 0).all())
+        elif key in ("pre_quant_scale",):
+            is_valid = bool(is_finite and (value > 0).all())
+        elif key in ("bias_value",):
+            is_valid = bool(is_finite)
+        else:
+            is_valid = bool(is_finite)
+
+        if is_valid:
+            return True
+
+        # Compact message to avoid dumping large tensors
+        name_prefix = f"{name} " if name else ""
+        shape = tuple(value.shape) if hasattr(value, "shape") else ()
+        any_nan = bool(torch.isnan(value).any())
+        any_inf = bool(torch.isinf(value).any())
+        msg = (
+            f"{name_prefix}{attr_name} invalid (shape={shape}, any_nan={any_nan}, any_inf={any_inf})."
+        )
+        if warn_error:
+            warnings.warn(msg)
+        if raise_error:
+            raise ValueError(msg)
+        return False

realAsma force-pushed the asma/remove_validate_amax_from_tq_forward branch from a548dea to 2f7b2eb on September 25, 2025 21:39
coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (6)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)

65-115: Unify LoRA buffer setters via _set_buffer; keep device/dtype; reduce code duplication

Setters duplicate buffer-registration logic and miss a consistent to(device/dtype) path.

Apply:

     @svdquant_lora_a.setter
     def svdquant_lora_a(self, value):
         """Lora a weights for svdquant."""
         assert value is not None, "svdquant_lora_a cannot be set to None."
-
-        if not isinstance(value, torch.Tensor):
-            value = torch.tensor(value)
-
-        if not hasattr(self, "_svdquant_lora_a"):
-            self.register_buffer("_svdquant_lora_a", value.clone().detach())
-        else:
-            if self._svdquant_lora_a.shape != value.shape:
-                raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
-            self._svdquant_lora_a.data.copy_(
-                value.clone().detach().to(self._svdquant_lora_a.device)
-            )
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+        if hasattr(self, "_svdquant_lora_a") and self._svdquant_lora_a.shape != value.shape:
+            raise RuntimeError("Changing shape when setting svdquant_lora_a is not allowed.")
+        value = value.clone().detach().to(
+            getattr(self, "_svdquant_lora_a", value).device if hasattr(self, "_svdquant_lora_a") else value.device
+        )
+        self._set_buffer("_svdquant_lora_a", value)
@@
     @svdquant_lora_b.setter
     def svdquant_lora_b(self, value):
         """Lora b weights for svdquant."""
         assert value is not None, "svdquant_lora_b cannot be set to None."
-
-        if not isinstance(value, torch.Tensor):
-            value = torch.tensor(value)
-
-        if not hasattr(self, "_svdquant_lora_b"):
-            self.register_buffer("_svdquant_lora_b", value.clone().detach())
-        else:
-            if self._svdquant_lora_b.shape != value.shape:
-                raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
-            self._svdquant_lora_b.data.copy_(
-                value.clone().detach().to(self._svdquant_lora_b.device)
-            )
+        if not isinstance(value, torch.Tensor):
+            value = torch.tensor(value)
+        if hasattr(self, "_svdquant_lora_b") and self._svdquant_lora_b.shape != value.shape:
+            raise RuntimeError("Changing shape when setting svdquant_lora_b is not allowed.")
+        value = value.clone().detach().to(
+            getattr(self, "_svdquant_lora_b", value).device if hasattr(self, "_svdquant_lora_b") else value.device
+        )
+        self._set_buffer("_svdquant_lora_b", value)

176-179: Avoid extra allocation when folding; use in-place add_

Current code materializes weight + residual then copies. Add in-place is cheaper.

Apply:

-                self.weight.data.copy_(
-                    self.weight
-                    + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
-                )
+                self.weight.data.add_(self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a)
modelopt/torch/quantization/model_calib.py (3)

574-585: Surface validation context and NaN/Inf reasons in logs

Pass name to validate_attr and enable warn_error for better diagnostics; you already warn, but validate_attr can emit attribute-aware context too once improved.

Apply:

-            for tq in [self.input_quantizer, self.weight_quantizer]:
-                for attr in ["_pre_quant_scale", "_amax"]:
-                    if not tq.validate_attr(attr_name=attr):
+            for tq in [self.input_quantizer, self.weight_quantizer]:
+                for attr in ["_pre_quant_scale", "_amax"]:
+                    if not tq.validate_attr(attr_name=attr, warn_error=True, name=name):
                         disable_awq = True
                         warnings.warn(
                             f"awq_lite: {attr} is not valid for {self.awq_lite.name}, skipping awq_lite"
                         )
                         break

633-653: Postprocess: guard against missing best_scale when is_enabled toggles late

If is_enabled flipped off during search, best_scale can remain None. Bail out to max_calibrate in that case.

Apply:

-        if module.awq_lite.is_enabled:
-            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        if module.awq_lite.is_enabled and module.awq_lite.best_scale is not None:
+            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
+        elif module.awq_lite.is_enabled:
+            warnings.warn(f"awq_lite: No best_scale for {name}, quantizing with max calibration.")
+            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))

656-663: No-op modules in second forward_loop: demote to rank-0 logs

These user-facing warnings will repeat on each rank. Prefer print_rank_0 for consistency.

Apply:

-                warnings.warn(
-                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
-                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
-                    " forward data through the model many times."
-                )
+                print_rank_0(
+                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
+                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
+                    " forward data through the model many times."
+                )
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

828-828: Consider warn_error=True when validating export amax

Surfacing invalid amax (shape/NaN/Inf) during export helps early detection.

Apply:

-        self.validate_attr(attr_name="_amax", attr_value=amax)
+        self.validate_attr(attr_name="_amax", attr_value=amax, warn_error=True)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a548dea and 2f7b2eb.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (13 hunks)
  • modelopt/torch/quantization/model_quant.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/model_quant.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (65-1107)
modelopt/torch/quantization/nn/modules/quant_module.py (3)
  • QuantLinearConvBase (129-169)
  • _setup (118-126)
  • _setup (163-169)
modelopt/torch/quantization/model_calib.py (4)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (8)
  • axis (283-285)
  • axis (288-290)
  • is_enabled (395-397)
  • forward (843-942)
  • disable (399-404)
  • validate_attr (507-523)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/network.py (2)
  • bind_forward_method (634-653)
  • unpatch_forward_method (656-660)
modelopt/torch/quantization/utils.py (1)
  • enable_weight_access_and_writeback (424-443)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

31-35: Harden GLOBALS import across PyTorch versions; add safe fallback

Private import paths change across versions; current hasattr check can still fail. Use try/except and a dummy fallback to avoid import-time crashes.

Apply:

-if hasattr(torch.onnx, "_globals"):
-    from torch.onnx._globals import GLOBALS
-else:  # torch >= 2.9
-    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+try:
+    from torch.onnx._globals import GLOBALS  # older PyTorch
+except Exception:
+    try:
+        from torch.onnx._internal.torchscript_exporter._globals import GLOBALS  # PyTorch >= 2.9
+    except Exception:
+        class _DummyGlobals:
+            in_onnx_export = False
+        GLOBALS = _DummyGlobals()

507-524: Make validate_attr attribute-aware; enforce pre_quant_scale > 0; avoid dumping large tensors

Current rule (>=0 and not inf) is wrong for pre_quant_scale (must be > 0) and too generic. Also avoid logging entire tensors.

Apply:

-    def validate_attr(
-        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
-    ):
-        """Validate attribute."""
-        attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
-        if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta):
-            return True
-        is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value))
-        if is_valid:
-            return True
-        name = f"{name} " if name else ""
-        msg = f"{name}{attr_name} contains invalid values: {attr_value}"
-        if warn_error:
-            warnings.warn(msg)
-        if raise_error:
-            raise ValueError(msg)
-        return False
+    def validate_attr(
+        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
+    ):
+        """Validate attribute value based on attribute semantics."""
+        value = attr_value if attr_value is not None else getattr(self, attr_name, None)
+        if value is None or (isinstance(value, torch.Tensor) and getattr(value, "is_meta", False)):
+            return True
+        key = attr_name.lstrip("_")
+        is_finite = torch.isfinite(value).all() if isinstance(value, torch.Tensor) else math.isfinite(float(value))
+        if key == "amax":
+            is_valid = bool(is_finite and (value >= 0).all())
+        elif key == "pre_quant_scale":
+            is_valid = bool(is_finite and (value > 0).all())
+        elif key in ("bias", "bias_value"):
+            is_valid = bool(is_finite)
+        else:
+            is_valid = bool(is_finite)
+        if is_valid:
+            return True
+        name_prefix = f"{name} " if name else ""
+        shape = tuple(value.shape) if hasattr(value, "shape") else ()
+        any_nan = bool(torch.isnan(value).any()) if isinstance(value, torch.Tensor) else False
+        any_inf = bool(torch.isinf(value).any()) if isinstance(value, torch.Tensor) else not is_finite
+        msg = f"{name_prefix}{attr_name} invalid (shape={shape}, any_nan={any_nan}, any_inf={any_inf})."
+        if warn_error:
+            warnings.warn(msg)
+        if raise_error:
+            raise ValueError(msg)
+        return False

Comment on lines 120 to 125
def _setup(self):
"""Overrides and bypass the _setup function."""
self.weight_quantizer.__class__ = SVDQuantTensorQuantizer


🛠️ Refactor suggestion | 🟠 Major

Call super()._setup() before reclassing the quantizer to avoid missing attributes

If _setup bypasses the base implementation, weight_quantizer may never be created, leading to AttributeError. Initialize via super() and then reclass.

Apply:

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+    def _setup(self):
+        """Override setup and upgrade the weight quantizer to SVDQuantTensorQuantizer."""
+        super()._setup()
+        if not isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
+            # Safe in CPython for pure-Python classes with identical layout
+            self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _setup(self):
"""Overrides and bypass the _setup function."""
self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
def _setup(self):
"""Override setup and upgrade the weight quantizer to SVDQuantTensorQuantizer."""
super()._setup()
if not isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
# Safe in CPython for pure-Python classes with identical layout
self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/quant_linear.py around lines 120 to
123, the _setup override currently reassigns weight_quantizer's class before the
base class has run, which can leave weight_quantizer uninitialized and raise
AttributeError; fix by calling super()._setup() first to allow the base
implementation to create and initialize weight_quantizer, then reclass
weight_quantizer to SVDQuantTensorQuantizer, preserving the original instance
and any required attributes.

realAsma force-pushed the asma/remove_validate_amax_from_tq_forward branch from 2f7b2eb to a72b7ea on September 25, 2025 22:32
coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (5)
modelopt/torch/quantization/model_quant.py (1)

108-112: Also validate bias_value during post-calib checks (after fixing validate_attr semantics).

Add "_bias_value" to the attributes validated post-calibration to catch invalid bias tensors. Note: ensure TensorQuantizer.validate_attr treats bias semantics correctly first (bias can be negative). Based on learnings

-    for name, module in model.named_modules():
-        if isinstance(module, TensorQuantizer):
-            for attr_name in ["_amax", "_pre_quant_scale"]:
-                module.validate_attr(attr_name=attr_name, warn_error=True, name=name)
+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            for attr_name in ["_amax", "_bias_value", "_pre_quant_scale"]:
+                module.validate_attr(attr_name=attr_name, warn_error=True, name=name)
modelopt/torch/quantization/nn/modules/quant_linear.py (1)

65-115: Preserve dtype on LoRA buffer updates.

When copying into existing buffers, match both device and dtype to avoid silent dtype flips across setters.

-            self._svdquant_lora_a.data.copy_(
-                value.clone().detach().to(self._svdquant_lora_a.device)
-            )
+            self._svdquant_lora_a.data.copy_(
+                value.clone().detach().to(device=self._svdquant_lora_a.device, dtype=self._svdquant_lora_a.dtype)
+            )
@@
-            self._svdquant_lora_b.data.copy_(
-                value.clone().detach().to(self._svdquant_lora_b.device)
-            )
+            self._svdquant_lora_b.data.copy_(
+                value.clone().detach().to(device=self._svdquant_lora_b.device, dtype=self._svdquant_lora_b.dtype)
+            )
modelopt/torch/quantization/model_calib.py (1)

575-586: Strengthen validation: enforce pre_quant_scale > 0 and catch NaN.

The disable-awq gate relies on TensorQuantizer.validate_attr. Current impl only checks “>=0 and not inf”; it misses NaN and allows zeros for pre_quant_scale. Tighten validate_attr semantics (see separate comment) so this gate reliably disables AWQ when scales are invalid.

After updating validate_attr, no code changes needed here; re-run a small calib pass to confirm AWQ isn’t disabled spuriously and that zero/NaN scales trigger the intended warning path.

Also applies to: 587-590
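A quick way to run that smoke test (sketch only; the dataloader is a placeholder and the config may be any AWQ-based config such as INT4_AWQ_CFG):

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # A handful of calibration batches is enough to exercise the awq_lite gating paths.
    for batch in small_calib_dataloader:  # placeholder dataloader
        model(**batch)

model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop)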

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

507-523: Make validate_attr attribute-aware; add NaN checks and compact logs.

Current “>=0 and not inf” rule is incorrect for pre_quant_scale (should be strictly > 0) and too strict for bias (can be negative). Also misses NaN. Implement per-attribute semantics and avoid logging full tensors.

-    def validate_attr(
-        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
-    ):
-        """Validate attribute."""
-        attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
-        if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta):
-            return True
-        is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value))
-        if is_valid:
-            return True
-        name = f"{name} " if name else ""
-        msg = f"{name}{attr_name} contains invalid values: {attr_value}"
-        if warn_error:
-            warnings.warn(msg)
-        if raise_error:
-            raise ValueError(msg)
-        return False
+    def validate_attr(
+        self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
+    ):
+        """Validate attribute value based on attribute semantics.
+        - amax: finite and >= 0
+        - pre_quant_scale: finite and > 0
+        - bias_value: finite (sign unconstrained)
+        - default: finite
+        """
+        value = attr_value if attr_value is not None else getattr(self, attr_name, None)
+        if value is None or (isinstance(value, torch.Tensor) and value.is_meta):
+            return True
+        is_finite = torch.isfinite(value).all()
+        key = attr_name.lstrip("_")
+        if key in ("amax",):
+            ok = bool(is_finite and (value >= 0).all())
+        elif key in ("pre_quant_scale",):
+            ok = bool(is_finite and (value > 0).all())
+        elif key in ("bias_value",):
+            ok = bool(is_finite)
+        else:
+            ok = bool(is_finite)
+        if ok:
+            return True
+        # Compact diagnostics
+        name_prefix = f"{name} " if name else ""
+        shape = tuple(value.shape) if hasattr(value, "shape") else ()
+        any_nan = bool(torch.isnan(value).any())
+        any_inf = bool(torch.isinf(value).any())
+        msg = f"{name_prefix}{attr_name} invalid (shape={shape}, any_nan={any_nan}, any_inf={any_inf})."
+        if warn_error:
+            warnings.warn(msg)
+        if raise_error:
+            raise ValueError(msg)
+        return False

828-828: Consider warning on invalid exported amax.

Optionally pass warn_error=True to surface issues during export formatting.

-        self.validate_attr(attr_name="_amax", attr_value=amax)
+        self.validate_attr(attr_name="_amax", attr_value=amax, warn_error=True, name=self.__class__.__name__)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2f7b2eb and a72b7ea.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py (13 hunks)
  • modelopt/torch/quantization/model_quant.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T16:40:12.799Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.799Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/model_quant.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_quant.py (2)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • validate_attr (507-523)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (65-1107)
modelopt/torch/quantization/nn/modules/quant_module.py (3)
  • QuantLinearConvBase (129-169)
  • _setup (118-126)
  • _setup (163-169)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (6)
modelopt/torch/quantization/model_calib.py (2)

440-443: Good guard: skip AWQ Lite when forward_loop is missing.

Prevents accidental errors and aligns with the "disable instead of error" objective.
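
A minimal sketch of such a guard, with assumed argument names and message text rather than the exact model_calib.py code:

import warnings

def awq_lite(model, forward_loop=None, **kwargs):
    # Without a forward_loop there is no calibration data, so skip AWQ-Lite
    # instead of raising and leave the model in its plain-quantized state.
    if forward_loop is None:
        warnings.warn("AWQ-Lite requires a forward_loop for calibration; skipping AWQ-Lite.")
        return
    ...  # remainder of the calibration flow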


609-609: Rank-aware messaging looks good.

Using print_rank_0 reduces log spam under DDP.

Also applies to: 631-631, 917-917
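
For reference, a rank-0-only print helper typically looks like the sketch below; this is an assumption about the general pattern, not the actual modelopt/torch/utils/logging.py implementation:

import torch.distributed as dist

def print_rank_0(*args, **kwargs):
    # Print only on rank 0 (or when not running distributed) so each message
    # appears once instead of once per DDP worker.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)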

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)

31-35: ONNX GLOBALS import fallback (prior suggestion).

A try/except import with a safe dummy fallback is more robust across PyTorch versions.
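
A sketch of such a fallback; the import path and the attribute read from GLOBALS are assumptions, since both are private and have moved between PyTorch releases:

try:
    from torch.onnx._globals import GLOBALS  # private module; location varies across versions
except ImportError:
    class _DummyGlobals:
        # Expose only the flag the quantizer consults, with a safe default.
        in_onnx_export = False

    GLOBALS = _DummyGlobals()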


630-631: Prefer ValueError over assert for missing block size.

Good change; the check can no longer be stripped when Python runs with -O, and it gives clearer error semantics.
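
An illustrative comparison only (the real attribute and message live in tensor_quantizer.py):

def _require_block_sizes(block_sizes):
    # assert block_sizes is not None, "block_sizes must be set"  # silently removed under python -O
    if block_sizes is None:  # survives -O and raises a conventional, catchable error
        raise ValueError("block_sizes must be set for block quantization")
    return block_sizes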


773-777: Clear error on dynamic shape change in block-quant path.

Good guard; prevents silent mis-quantization.
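
A minimal sketch of this kind of guard; the function name, stored value, and message are assumptions, not the actual tensor_quantizer.py code:

import torch

def _check_block_quant_shape(calibrated_in_features: int, inputs: torch.Tensor) -> None:
    # The block-quant state (amax layout, number of blocks) was sized for
    # calibrated_in_features, so a different feature dimension at runtime must
    # fail loudly instead of silently mis-quantizing.
    if inputs.shape[-1] != calibrated_in_features:
        raise RuntimeError(
            f"Last input dimension changed from {calibrated_in_features} to "
            f"{inputs.shape[-1]}; the block quantization state is no longer valid."
        )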

modelopt/torch/quantization/nn/modules/quant_linear.py (1)

120-125: Bug: _setup bypasses base init; weight_quantizer may be undefined.

Accessing self.weight_quantizer before super()._setup() risks AttributeError. Initialize via base and then reclass.

-    def _setup(self):
-        """Overrides and bypass the _setup function."""
-        if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
-            return
-        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
+    def _setup(self):
+        """Override setup and upgrade the weight quantizer to SVDQuantTensorQuantizer."""
+        super()._setup()
+        if hasattr(self, "weight_quantizer") and not isinstance(
+            self.weight_quantizer, SVDQuantTensorQuantizer
+        ):
+            # Safe for pure-Python classes with identical layout
+            self.weight_quantizer.__class__ = SVDQuantTensorQuantizer


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a72b7ea and 08c2628.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/model_calib.py (13 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/model_calib.py (4)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • forward (843-942)
  • validate_attr (507-523)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/network.py (2)
  • bind_forward_method (634-653)
  • unpatch_forward_method (656-660)
modelopt/torch/quantization/utils.py (1)
  • enable_weight_access_and_writeback (424-443)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

507-523: Make validate_attr attribute-aware (bias can be negative, pre-quant scale must be > 0).

This helper now treats every attribute as "finite and ≥ 0". That's fine for amax, but it incorrectly rejects perfectly valid negative bias_value tensors (bias calibration routinely produces signed values) and it lets a zero pre_quant_scale slip through. As soon as warn_error or raise_error is enabled for those attributes, we'll surface false alarms or break existing flows. Please reinstate per-attribute rules.

     def validate_attr(
         self, attr_value=None, attr_name="amax", raise_error=False, warn_error=False, name=""
     ):
         """Validate attribute."""
-        attr_value = attr_value if attr_value is not None else getattr(self, attr_name, None)
-        if attr_value is None or (isinstance(attr_value, torch.Tensor) and attr_value.is_meta):
-            return True
-        is_valid = torch.all(attr_value >= 0) and not torch.any(torch.isinf(attr_value))
-        if is_valid:
+        value = attr_value if attr_value is not None else getattr(self, attr_name, None)
+        if value is None or (isinstance(value, torch.Tensor) and value.is_meta):
+            return True
+
+        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
+        key = attr_name.lstrip("_")
+
+        if tensor.is_floating_point() or tensor.is_complex():
+            finite_mask = torch.isfinite(tensor)
+        else:
+            finite_mask = torch.ones_like(tensor, dtype=torch.bool)
+
+        if key in {"amax"}:
+            valid_mask = finite_mask & (tensor >= 0)
+        elif key in {"pre_quant_scale"}:
+            valid_mask = finite_mask & (tensor > 0)
+        elif key in {"bias", "bias_value"}:
+            valid_mask = finite_mask
+        else:
+            valid_mask = finite_mask
+
+        if valid_mask.all():
             return True
         name = f"{name}." if name else ""
-        msg = f"{name}{attr_name} contains invalid values: {attr_value}"
+        any_nan = bool(tensor.is_floating_point() and torch.isnan(tensor).any())
+        any_inf = bool(tensor.is_floating_point() and torch.isinf(tensor).any())
+        shape = tuple(tensor.shape)
+        msg = (
+            f"{name}{attr_name} contains invalid values "
+            f"(shape={shape}, any_nan={any_nan}, any_inf={any_inf})"
+        )
         if warn_error:
             warnings.warn(msg)
         if raise_error:
             raise ValueError(msg)
         return False

…u sync, new SVDQuantTensorQuantizer, minor code clean up

Signed-off-by: realAsma <[email protected]>
@realAsma realAsma force-pushed the asma/remove_validate_amax_from_tq_forward branch from 08c2628 to b8e5ad2 on September 26, 2025 18:10
@realAsma realAsma requested a review from mxinO September 26, 2025 18:13
Contributor

@kinjalpatel27 kinjalpatel27 left a comment


LGTM

@realAsma realAsma merged commit 2b13b67 into main Sep 26, 2025
27 checks passed
@realAsma realAsma deleted the asma/remove_validate_amax_from_tq_forward branch September 26, 2025 20:48
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025
… new SVDQuantTensorQuantizer, minor code clean up (#357)

Signed-off-by: realAsma <[email protected]>
Signed-off-by: Ye Yu <[email protected]>