
Conversation

jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Sep 15, 2025

What does this PR do?

Type of change: new feature

Overview:

  1. Added support for exporting dynamic block quantization to ONNX for both the MHA and linear layers.
  2. Fixed a minor bug in the diffusion example.

Usage

FP8_SAGE_DEFAULT_CONFIG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        "*output_quantizer": {"enable": False},
        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
        "*softmax_quantizer": {
            "num_bits": (4, 3),
            "axis": None,
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}

mtq.quantize(model, FP8_SAGE_DEFAULT_CONFIG, forward_func)

torch.onnx.export(model, ...) # you can follow the diffusion example at https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/diffusers/quantization
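
A minimal export sketch to make the command concrete; the dummy input shape, opset version, and output file name are illustrative assumptions, not values required by this PR (the full pipeline lives in the diffusers example linked above):

import torch

# Assumes `model` is the module quantized with FP8_SAGE_DEFAULT_CONFIG above and
# that the single dummy tensor matches its forward signature; the shape is a placeholder.
dummy_inputs = (torch.randn(1, 24, 4608, 128, dtype=torch.float16),)
with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_inputs,
        "model_fp8_blockwise.onnx",
        opset_version=17,
        do_constant_folding=False,  # keep the Q/DQ patterns visible for TensorRT
    )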

Testing

  1. This is an evaluation feature since the TRT kernel isn't ready. No test cases are required at this time; we will add test cases next month.
  2. However, we can continue relying on the existing FP8 per-tensor test cases as usual.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features

    • Optional FP8 blockwise quantization for tensors and attention (Q/K/V), configurable at runtime and preserved in ONNX export.
    • New FP8 SAGE preset enabling dynamic blockwise quantization for attention matmuls.
  • Improvements

    • MHA quantization driven by runtime configuration instead of a class default.
    • Per-dimension block-size support propagated end-to-end through attention, export, and symbolic paths.
    • Cleanup to avoid lingering temporary quantization state after forward calls.

@jingyu-ml jingyu-ml self-assigned this Sep 15, 2025
@jingyu-ml jingyu-ml requested review from a team as code owners September 15, 2025 23:06

coderabbitai bot commented Sep 15, 2025

Walkthrough

Adds optional FP8 blockwise quantization end-to-end: new FP8_SAGE_DEFAULT_CONFIG, runtime-driven quantize_mha flag, propagation of per-tensor vs dynamic block shapes through diffusers attention and ONNX symbolic/export paths, TensorQuantizer and ScaledE4M3 support for block sizes, and ONNX helpers for blockwise FP8 quantize/dequantize.

Changes

Cohort / File(s) Summary
Diffusers config & usage
examples/diffusers/quantization/config.py, examples/diffusers/quantization/quantize.py
Adds FP8_SAGE_DEFAULT_CONFIG with dynamic QKV block settings; quantize.py now reads quant_config.quantize_mha at runtime (uses instance flag rather than class default).
ONNX FP8 export (blockwise support)
modelopt/torch/quantization/export_onnx.py
Adds _fp8_block_quantize and _fp8_block_dequantize. export_fp8 accepts `amax: float | None` and an optional `block_sizes` list, and routes to blockwise or per-tensor FP8 quantize/dequantize accordingly.
Tensor quantizer internals
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Adds TensorQuantizer._get_block_sizes_list(self, shape); stores _original_input_shape during forward and removes it after; FP8 fake-quant path derives per-dimension block_sizes_list and passes it into scaled_e4m3.
Diffusers FP8 SDPA path
modelopt/torch/quantization/plugins/diffusers.py
Handles dynamic vs non-dynamic Q/K/V quantizers; computes q/k/v block sizes and forwards them. Updates _QuantAttention and FP8SDPA forward and symbolic signatures to accept q_block_shape, k_block_shape, v_block_shape.
Tensor quant function + symbolic
modelopt/torch/quantization/tensor_quant.py
ScaledE4M3Function parse_args/symbolic extended to include block_sizes; forward adds block_sizes param; ONNX export now passes block_sizes to export_fp8. Backward call updated to account for extra arg.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant QC as QuantConfig
  participant Attn as _QuantAttention/FP8SDPA
  participant Sym as FP8SDPA.symbolic
  participant Export as export_fp8_mha
  participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant

  User->>QC: load FP8_SAGE_DEFAULT_CONFIG
  User->>Attn: forward(query, key, value, ...)
  Attn->>Attn: detect dynamic vs non-dynamic Q/K/V
  Attn->>Attn: compute q/k/v block_shapes via _get_block_sizes_list()
  Attn->>Sym: symbolic(..., q_block_shape, k_block_shape, v_block_shape)
  Sym->>Export: export_fp8_mha(..., q_block_shape, k_block_shape, v_block_shape)
  alt block shapes present
    Export->>BlockHelpers: _fp8_block_quantize(Q/K/V, block_shape)
    BlockHelpers-->>Export: quantized uint8 + scales
    Export->>BlockHelpers: _fp8_block_dequantize(..., block_shape)
    BlockHelpers-->>Export: dequantized tensors
  else no block shapes
    Export->>Export: per-tensor FP8 quantize/dequantize
  end
  Export-->>User: ONNX graph with FP8 (blockwise or per-tensor)
sequenceDiagram
  autonumber
  participant TensorQ as TensorQuantizer
  participant SE4 as ScaledE4M3Function
  participant Export as export_fp8
  participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant

  TensorQ->>TensorQ: _get_block_sizes_list(_original_input_shape)
  TensorQ->>SE4: forward(x, scale, amax, block_sizes_list, ...)
  alt ONNX export path
    SE4->>Export: export_fp8(..., amax=None|float, block_sizes)
    opt block_sizes provided
      Export->>BlockHelpers: _fp8_block_quantize/_fp8_block_dequantize
    end
  else eager fake-quant
    SE4->>SE4: apply fake-quant with block sizes
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble bytes in tidy blocks and map each tiny shape,
Q, K, V hop in tiles and glide—no scale shall escape.
I stitch the graph with nimble paws and fork each export trail,
Blockwise carrots lined in rows—ONNX smells like ale.
Hop—export done; I twitch my nose and tuck a crunchy tail. 🥕🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title "FP8 Block quantize onnx export support" accurately and concisely summarizes the primary change of adding FP8 block quantization export capabilities to ONNX.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 89.47%, which meets the required threshold of 80.00%.


@jingyu-ml jingyu-ml marked this pull request as draft September 15, 2025 23:06

copy-pr-bot bot commented Sep 15, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml force-pushed the jingyux/block-quant-onnx branch from e4d1775 to 071f167 on September 15, 2025 23:14
@jingyu-ml jingyu-ml requested a review from kaix-nv September 15, 2025 23:15
@jingyu-ml jingyu-ml marked this pull request as ready for review September 15, 2025 23:15
Signed-off-by: Jingyu Xin <[email protected]>

codecov bot commented Sep 15, 2025

Codecov Report

❌ Patch coverage is 17.07317% with 34 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.68%. Comparing base (70abfb4) to head (f80b847).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/export_onnx.py 8.33% 22 Missing ⚠️
.../torch/quantization/nn/modules/tensor_quantizer.py 28.57% 10 Missing ⚠️
modelopt/torch/quantization/tensor_quant.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #324      +/-   ##
==========================================
- Coverage   73.79%   73.68%   -0.11%     
==========================================
  Files         171      171              
  Lines       17583    17616      +33     
==========================================
+ Hits        12975    12981       +6     
- Misses       4608     4635      +27     

☔ View full report in Codecov by Sentry.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (11)
examples/diffusers/quantization/config.py (1)

39-39: Fix spacing inconsistency in configuration.

The configuration has inconsistent spacing after commas. Line 39 has a missing space after the comma between (4, 3) and "block_sizes".

-        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
+        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
modelopt/torch/quantization/export_onnx.py (2)

237-263: Consider adding validation for block_sizes parameter.

The new _fp8_block_quantize function should validate the structure and values of the block_sizes parameter to prevent runtime errors during ONNX export.

Add validation at the beginning of the function:

 def _fp8_block_quantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     trt_high_precision_dtype: str,
     block_sizes: list,
 ):
     """Helper Function for Quantization."""
+    if not isinstance(block_sizes, list) or not block_sizes:
+        raise ValueError(f"block_sizes must be a non-empty list, got {block_sizes}")
+    if not all(isinstance(b, int) and b > 0 for b in block_sizes):
+        raise ValueError(f"All block sizes must be positive integers, got {block_sizes}")
+        
     output_shape = sym_help._get_tensor_sizes(inputs)

534-535: Fix typo in comment.

-        # We cannot do block quant for the softmax's output 
+        # We cannot do block quant for the softmax's output
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

653-653: Remove trailing whitespace from blank lines.

These blank lines contain unnecessary whitespace which violates Python style guidelines.

-        
+
         Args:
             shape: The tensor shape to use for conversion (can be tuple or torch.Size)
-            
+
         Returns:
             List of block sizes for each dimension, or None if block_sizes is None
-            
+
         Example:

Also applies to: 656-656, 659-659


961-962: Consider thread-safety for the _original_input_shape attribute.

Setting and deleting _original_input_shape as a temporary attribute could cause issues in multi-threaded scenarios where the same quantizer is used by multiple threads simultaneously.

Consider using a context manager or local variable approach instead:

-            setattr(self, "_original_input_shape", inputs.shape)
-            inputs = self._process_for_blockquant(inputs)
+            original_shape = inputs.shape
+            inputs = self._process_for_blockquant(inputs)
+            # Pass original_shape to methods that need it

Alternatively, consider storing it in a thread-local storage if multi-threading support is required.
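
A rough sketch of that context-manager/thread-local alternative, with hypothetical names (not the quantizer's actual implementation):

import contextlib
import threading

_tls = threading.local()  # one scratch slot per thread instead of a module attribute

@contextlib.contextmanager
def original_shape_scope(inputs):
    # Record the pre-reshape shape for the duration of the forward call and
    # always clean it up, even if quantization raises.
    _tls.original_input_shape = inputs.shape
    try:
        yield _tls.original_input_shape
    finally:
        del _tls.original_input_shape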

modelopt/torch/quantization/plugins/diffusers.py (6)

117-124: Fix assertion logic for mixed dynamic/non-dynamic quantizers

The current implementation requires all QKV quantizers to be either dynamic or non-dynamic together. However, the logic flow is problematic - if they're all non-dynamic, scales are computed, but if any is dynamic, it asserts all must be dynamic. This creates a rigid constraint that may not be necessary for all use cases.

Consider refactoring to handle mixed cases more gracefully:

-    if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic:
-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
-    else:
-        assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
-        q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
+    # Compute scales for non-dynamic quantizers, set None for dynamic ones
+    q_quantized_scale = None if self.q_bmm_quantizer._dynamic else self.q_bmm_quantizer._get_amax(query)
+    k_quantized_scale = None if self.k_bmm_quantizer._dynamic else self.k_bmm_quantizer._get_amax(key)
+    v_quantized_scale = None if self.v_bmm_quantizer._dynamic else self.v_bmm_quantizer._get_amax(value)
+    
+    # Optionally validate consistency if needed
+    dynamic_states = [self.q_bmm_quantizer._dynamic, self.k_bmm_quantizer._dynamic, self.v_bmm_quantizer._dynamic]
+    if len(set(dynamic_states)) > 1:
+        # Log warning or handle mixed dynamic states if necessary
+        pass

122-122: Fix line length violation

Line 122 exceeds the 120 character limit (149 characters).

-        assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
+        assert (self.q_bmm_quantizer._dynamic and 
+                self.k_bmm_quantizer._dynamic and 
+                self.v_bmm_quantizer._dynamic), "QKV QDQS must be in the same type"

144-146: Remove trailing whitespace

Line 145 has trailing whitespace after the comma.

             q_block_sizes,
-            k_block_sizes, 
+            k_block_sizes,
             v_block_sizes,

231-233: Inconsistent default values for scale parameters

The scale parameters have inconsistent default values in the symbolic method signature (float | None = 1.0) which doesn't match the forward method where they default to None.

-        q_quantized_scale: float | None = 1.0,
-        k_quantized_scale: float | None = 1.0,
-        v_quantized_scale: float | None = 1.0,
+        q_quantized_scale: float | None = None,
+        k_quantized_scale: float | None = None,
+        v_quantized_scale: float | None = None,

200-202: Consider using TypeAlias for block shape type consistency

The block shape parameters use list | None type annotations repeatedly. Consider defining a type alias for better maintainability and consistency.

Add at the top of the file after imports:

from typing import TypeAlias

BlockShape: TypeAlias = list[int] | None

Then update the signatures:

-        q_block_shape: list | None = None,
-        k_block_shape: list | None = None,
-        v_block_shape: list | None = None,
+        q_block_shape: BlockShape = None,
+        k_block_shape: BlockShape = None,
+        v_block_shape: BlockShape = None,

Also applies to: 236-238


126-128: Add validation for block sizes consistency

The code retrieves block sizes from quantizers but doesn't validate that they're compatible with the actual tensor shapes or the quantization configuration.

Consider adding validation to ensure block sizes are appropriate:

     # Get block sizes lists for each quantizer if needed
     q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)
     k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)
     v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)
+    
+    # Validate block sizes if dynamic quantization is enabled
+    if self.q_bmm_quantizer._dynamic and q_block_sizes:
+        for dim, block_size in enumerate(q_block_sizes):
+            if block_size > 1 and query.shape[dim] % block_size != 0:
+                raise ValueError(f"Query dimension {dim} (size {query.shape[dim]}) is not divisible by block size {block_size}")
+    # Similar validation for k and v can be added
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 071f167.

📒 Files selected for processing (6)
  • examples/diffusers/quantization/config.py (1 hunks)
  • examples/diffusers/quantization/quantize.py (1 hunks)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
  • modelopt/torch/quantization/tensor_quant.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (540-549)
  • _get_block_sizes_list (651-672)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (407-409)
  • trt_high_precision_dtype (412-413)
  • block_sizes (289-291)
  • block_sizes (294-296)
  • amax (233-238)
  • amax (241-252)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (7)
  • amax (233-238)
  • amax (241-252)
  • block_sizes (289-291)
  • block_sizes (294-296)
  • trt_high_precision_dtype (407-409)
  • trt_high_precision_dtype (412-413)
  • forward (902-1004)
modelopt/torch/quantization/export_onnx.py (1)
  • export_fp8 (321-342)
🪛 GitHub Actions: Code Quality
modelopt/torch/quantization/plugins/diffusers.py

[error] 122-122: E501 Line too long (149 > 120).


[error] 134-136: mypy: Item 'str' of 'str | float | None' has no attribute 'shape' [union-attr] (lines 134-136).

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

[warning] 653-653: W293 Blank line contains whitespace.


[warning] 656-656: W293 Blank line contains whitespace.


[warning] 659-659: W293 Blank line contains whitespace.


[warning] 959-959: RUF003 Comment contains ambiguous RIGHT SINGLE QUOTATION MARK ’. Did you mean `` (GRAVE ACCENT)?

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
🔇 Additional comments (6)
examples/diffusers/quantization/quantize.py (1)

942-942: LGTM! Runtime configuration for MHA quantization is properly handled.

The change from QuantizationConfig.quantize_mha to quant_config.quantize_mha correctly uses the runtime configuration for MHA quantization. This aligns with the overall pattern of making MHA quantization configurable.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

651-673: LGTM! Well-designed helper method for block size conversion.

The _get_block_sizes_list method is well-implemented with proper handling of both positive and negative dimension indices. The documentation with examples clearly explains its purpose.
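
For readers skimming the thread, here is a standalone sketch of the mapping this helper performs; get_block_sizes_list below is a hypothetical mirror of the method, not the module's code:

def get_block_sizes_list(block_sizes, shape):
    """Expand a block_sizes dict into a per-dimension list, defaulting to 1."""
    if block_sizes is None:
        return None
    sizes = []
    for dim in range(len(shape)):
        # Look up both the positive and the negative index for this dimension.
        size = block_sizes.get(dim) or block_sizes.get(dim - len(shape))
        sizes.append(size if size is not None else 1)
    return sizes

# {-2: 32} over a [2, 24, 4608, 128] tensor -> [1, 1, 32, 1]
print(get_block_sizes_list({-2: 32}, (2, 24, 4608, 128)))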

modelopt/torch/quantization/export_onnx.py (1)

298-298: Incorrect — default block_sizes is not exercised by current callers

Call sites found in modelopt/torch/quantization/export_onnx.py pass block_sizes into _fp8_block_dequantize, so the hardcoded [1,1,128,1] default isn't used; no change required.

Likely an incorrect or invalid review comment.

modelopt/torch/quantization/tensor_quant.py (1)

415-415: No action required — parse_args('i' → 'is') is backward-compatible.

'is' accepts both integer and integer-list, so existing callers that pass integers remain valid (repo tests call scaled_e4m3(..., 4, 3)).

modelopt/torch/quantization/plugins/diffusers.py (2)

87-148: Good implementation of block quantization propagation!

The changes effectively extend the FP8 quantization path to support dynamic block quantization:

  • Proper handling of dynamic vs non-dynamic quantizers
  • Clean propagation of block shape parameters through the call stack
  • Maintains backward compatibility with existing code

221-221: parse_args: 'is' is correct for int-list (or None)

Matches existing usage in modelopt/torch/quantization/tensor_quant.py — the three "is" entries correctly map the three block-shape parameters to int[] | None.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

959-963: Ensure _original_input_shape is always cleaned up (wrap in try/finally)

State can leak if an exception occurs between setting and deleting _original_input_shape. Wrap the forward section after setup in try/finally.

-        if (
+        cleanup_original_shape = False
+        if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic isnt handled in the simulation kernel
             self._setup_for_blockquant(inputs)
             setattr(self, "_original_input_shape", inputs.shape)
+            cleanup_original_shape = True
             inputs = self._process_for_blockquant(inputs)
 
-        outputs = inputs
+        try:
+            outputs = inputs
             ...
-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
+        finally:
+            if cleanup_original_shape and hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")

Also applies to: 1002-1003

🧹 Nitpick comments (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

681-699: Comment is misleading within dynamic-only branch

The comment says “Double scale Block quantization, including dynamic and static block quantization” but this branch executes only when type == "dynamic". Tighten the comment to avoid confusion.

-            # Double scale Block quantization, including dynamic and static block quantization
+            # Dynamic double-scale block quantization path
modelopt/torch/quantization/plugins/diffusers.py (2)

117-132: Align QKV mode and error message; avoid computing scales in export path

  • Message “QKV QDQS must be in the same type” is unclear. Make it explicit: “Q, K, and V quantizers must all be dynamic or all be static.”
  • Skip _get_amax when exporting; it’s unused at runtime and can be None for dynamic. Guard by torch.onnx.is_in_onnx_export().
-    if (
+    if (
         not self.q_bmm_quantizer._dynamic
         and not self.k_bmm_quantizer._dynamic
         and not self.v_bmm_quantizer._dynamic
     ):
-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        if not torch.onnx.is_in_onnx_export():
+            q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
+            k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
+            v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        else:
+            q_quantized_scale = k_quantized_scale = v_quantized_scale = None
     else:
         assert (
             self.q_bmm_quantizer._dynamic
             and self.k_bmm_quantizer._dynamic
             and self.v_bmm_quantizer._dynamic
-        ), "QKV QDQS must be in the same type"
+        ), "Q, K, and V quantizers must all be dynamic or all be static."
         q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None

133-137: Using a private helper across modules; consider promoting to public API

Calling _get_block_sizes_list from another module couples to a private method. Expose it as a public helper (e.g., get_block_sizes_list or a util function) to avoid brittle dependencies.

Would you like me to extract a small utility, e.g., modelopt/torch/quantization/utils/block_sizes.py:get_block_sizes_list(shape, block_sizes)?

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 071f167 and 831c32d.

📒 Files selected for processing (4)
  • examples/diffusers/quantization/config.py (1 hunks)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/diffusers/quantization/config.py
  • modelopt/torch/quantization/export_onnx.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (540-549)
  • _get_block_sizes_list (651-672)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
🔇 Additional comments (2)
modelopt/torch/quantization/plugins/diffusers.py (2)

151-155: Default disable flag may be surprising

You pass True when _disable_fp8_mha is absent, which disables FP8 MHA by default. Confirm this is intended for evaluation builds, or flip default to False.

Would you like a config flag gate to avoid silent disablement in production?


251-269: Verified — export_fp8_mha signature matches callsite

Definition in modelopt/torch/quantization/export_onnx.py (def export_fp8_mha at ~line 420) includes q_block_shape, k_block_shape, v_block_shape; the call in modelopt/torch/quantization/plugins/diffusers.py passes them — no mismatch found.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

703-716: Fix scaled_e4m3 call-site argument order

scaled_e4m3 now expects block_sizes before bias; update all callers to the new signature.

  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py:707 (new usage)
  • modelopt/torch/quantization/calib/histogram.py:311
  • tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py:55
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py:148, 158, 166, 173, 185, 187, 202

Use: scaled_e4m3(inputs, amax, block_sizes, bias, E, M, ...). If no block_sizes, pass None as the third argument and move bias to the fourth.

♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

959-963: Guarantee cleanup of _original_input_shape on exceptions.

Deletion is unconditional but not exception-safe; wrap the quantization region in try/finally. Also fix the typo “isnt” → “isn't”.

-        if (
+        cleanup_original_input_shape = False
+        if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
-            # Reshape is required if the logic isnt handled in the simulation kernel
+            # Reshape is required if the logic isn't handled in the simulation kernel
             self._setup_for_blockquant(inputs)
             setattr(self, "_original_input_shape", inputs.shape)
+            cleanup_original_input_shape = True
             inputs = self._process_for_blockquant(inputs)
 
-        outputs = inputs
+        try:
+            outputs = inputs
@@
-        if (
+            if (
                 self.block_sizes is not None
                 and self.block_sizes.get("type", None) != "dynamic"
                 and self._fake_quant
-        ):
-            outputs = self._reset_to_original_shape(outputs)
-
-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
-        return outputs
+            ):
+                outputs = self._reset_to_original_shape(outputs)
+            return outputs
+        finally:
+            if cleanup_original_input_shape and hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")

Also applies to: 1002-1003

🧹 Nitpick comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

651-673: Type-safety and input validation for _get_block_sizes_list.

Add explicit typing, validate keys, and guard length mismatches to avoid silently passing malformed shapes to downstream ONNX ops.

-def _get_block_sizes_list(self, shape):
+from typing import Sequence
+
+def _get_block_sizes_list(self, shape: Sequence[int] | torch.Size) -> list[int] | None:
@@
-        block_sizes_list = []
-        for dim in range(len(shape)):
+        # Only allow integer axes plus known metadata keys.
+        valid_meta = {"type", "scale_bits", "scale_block_sizes"}
+        assert all(
+            isinstance(k, int) or k in valid_meta for k in self.block_sizes.keys()
+        ), f"Invalid block_sizes keys: {list(self.block_sizes.keys())}"
+
+        rank = len(shape)
+        block_sizes_list: list[int] = []
+        for dim in range(rank):
             # Check both positive and negative dimension indices
-            dim_negative = dim - len(shape)
+            dim_negative = dim - rank
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list
modelopt/torch/quantization/export_onnx.py (2)

238-265: Validate block_shape rank and surface clearer errors.

Add a rank check before emitting TRT_DynamicQuantize; mis-sized block_shapes currently fall through to TRT with cryptic errors.

 def _fp8_block_quantize(
@@
-    input_type = inputs.type().scalarType()
+    input_type = inputs.type().scalarType()
+    rank = symbolic_helper._get_tensor_rank(inputs)
+    assert rank is not None, "Input rank must be known at export time."
+    assert len(block_sizes) == rank, (
+        f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+    )
@@
     quantized_output, scales_output = g.op(
         "trt::TRT_DynamicQuantize",
         inputs,
         block_shape_i=block_sizes,

503-509: Block-shape consistency in FP8 MHA path.

Validate q/k/v block shapes match input ranks; also ensure softmax path never receives a block shape.

-        query_scaled = export_fp8(
-            g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
-        )
+        assert (q_block_shape is None) or (
+            len(q_block_shape) == symbolic_helper._get_tensor_rank(query_scaled)
+        ), "q_block_shape rank mismatch."
+        query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape)
@@
-        key_transposed_scaled = export_fp8(
-            g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
-        )
+        assert (k_block_shape is None) or (
+            len(k_block_shape) == symbolic_helper._get_tensor_rank(key_transposed_scaled)
+        ), "k_block_shape rank mismatch."
+        key_transposed_scaled = export_fp8(g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape)
@@
-        # We cannot do block quant for the softmax's output
-        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
+        # We cannot do block quant for the softmax's output
+        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
@@
-        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
+        assert (v_block_shape is None) or (
+            len(v_block_shape) == symbolic_helper._get_tensor_rank(value)
+        ), "v_block_shape rank mismatch."
+        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)

Also applies to: 535-549

modelopt/torch/quantization/plugins/diffusers.py (2)

117-132: Q/K/V quantization mode must match: improve error and skip redundant work.

The assertion is good. Minor: clarify message and avoid computing per-tensor amax if any quantizer is dynamic.

-        ), "QKV QDQS must be in the same type"
+        ), "Q/K/V quantization modes must match: either all dynamic or all static."

133-137: Guard block size list creation when block_sizes is None.

If a quantizer has no block_sizes, _get_block_sizes_list returns None; that's fine. Add a quick comment to make intent explicit and future-proof.

-    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+    # Returns None for per-tensor paths; ONNX export handles that by taking the non-block path.
+    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 071f167 and 831c32d.

📒 Files selected for processing (4)
  • examples/diffusers/quantization/config.py (1 hunks)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/diffusers/quantization/config.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (407-409)
  • trt_high_precision_dtype (412-413)
  • block_sizes (289-291)
  • block_sizes (294-296)
  • amax (233-238)
  • amax (241-252)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (540-549)
  • _get_block_sizes_list (651-672)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/diffusers.py (1)

229-231: Keep 't' for q/k/v scales — export extracts consts

export_onnx.py already extracts constant floats (uses sym_help._get_const / _maybe_get_const for scale/amax), so the current parse_args ("..., 't','t','t', ...") is fine; only change those three to 'f' if the export_fp8 const-extraction fix is removed. Location: modelopt/torch/quantization/plugins/diffusers.py (symbolic decorator around lines 229–231; same check applies to lines ~241–249).

Signed-off-by: Jingyu Xin <[email protected]>
Contributor

@Edwardf0t1 Edwardf0t1 left a comment


Please share a sample command for ONNX exporting for a supported model as well in the description.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

649-671: Return None for “metadata‑only” block_sizes and add typing.

Avoid forcing block mode with [1,…,1] when there are no per‑dim entries; this can route ONNX path incorrectly.

Apply:

-    def _get_block_sizes_list(self, shape):
+    def _get_block_sizes_list(self, shape) -> list[int] | None:
         """Convert block_sizes dict to list format based on tensor shape.
@@
-        if self.block_sizes is None:
+        if self.block_sizes is None:
             return None
-
-        block_sizes_list = []
+        # If there are no integer dimension entries with a meaningful block size, treat as no block quant.
+        has_dim_sizes = any(
+            isinstance(k, int) and (v is not None and v != 1) for k, v in self.block_sizes.items()
+        )
+        if not has_dim_sizes:
+            return None
+
+        block_sizes_list: list[int] = []
         for dim in range(len(shape)):
             # Check both positive and negative dimension indices
             dim_negative = dim - len(shape)
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list

1005-1007: Centralize deletion of _original_input_shape in a finally block.

Move this into a single finally at the end of forward so it runs regardless of success/failure.

Apply:

-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
-        return outputs
+        try:
+            return outputs
+        finally:
+            if hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")
modelopt/torch/quantization/export_onnx.py (3)

294-321: Remove brittle default [1,1,128,1] and validate rank in _fp8_block_dequantize.

Defaulting silently is dangerous and shape-dependent. Require explicit block_sizes and assert correctness. (Same concern raised earlier.)

Apply:

-def _fp8_block_dequantize(
+def _fp8_block_dequantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     scales: torch.Value,
     trt_high_precision_dtype: str,
     otype: str | None = None,
-    block_sizes: list = [1, 1, 128, 1],
+    block_sizes: list,
 ):
     """Helper Function for Dequantization."""
     output_shape = sym_help._get_tensor_sizes(inputs)
+    # Validate block shape
+    rank = sym_help._get_tensor_rank(inputs)
+    assert rank is not None, "Input rank must be known at export time."
+    assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, (
+        f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+    )
+    assert all(isinstance(b, int) and b > 0 for b in block_sizes), (
+        "All entries in block_shape must be positive integers."
+    )
+    if otype is None:
+        otype = inputs.type().scalarType()

323-345: Handle non-Python amax safely and validate block_shapes before block Q/DQ path.

float(amax) will break when amax is a graph Value/0‑dim tensor; also assert block_shapes align with input rank before calling block ops. (Echoing prior comment.)

Apply:

 def export_fp8(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
-    amax: float | None,
+    amax: float | None,
     trt_high_precision_dtype: str | None,
     block_sizes: list | None,
 ):
     """Export quantized model to FP8 ONNX."""
-    scale = 1.0 if amax is None else 448.0 / float(amax)
+    if amax is None:
+        scale = 1.0
+    else:
+        amax_const = sym_help._get_const(amax, "f", "amax")
+        # If not a constant at export time, fall back to neutral scale to avoid exporter errors.
+        scale = 1.0 if (amax_const is None or amax_const == 0) else 448.0 / float(amax_const)
@@
-    if not block_sizes:
+    if not block_sizes:
         q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
         return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
     else:
+        # Validate block shape early
+        rank = sym_help._get_tensor_rank(inputs)
+        assert rank is not None, "Input rank must be known at export time."
+        assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, (
+            f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+        )
+        assert all(isinstance(b, int) and b > 0 for b in block_sizes), (
+            "All entries in block_shape must be positive integers."
+        )
         q_tensor, scales_output = _fp8_block_quantize(
             g, inputs, trt_high_precision_dtype, block_sizes
         )
         return _fp8_block_dequantize(
             g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
         )

238-265: Validate block_shape against input rank and values in _fp8_block_quantize.

Guard against mismatched ranks and non-positive entries to avoid invalid custom op attributes at export time.

Apply:

 def _fp8_block_quantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     trt_high_precision_dtype: str,
     block_sizes: list,
 ):
     """Helper Function for Quantization."""
     output_shape = sym_help._get_tensor_sizes(inputs)
 
+    # Validate block shape
+    rank = sym_help._get_tensor_rank(inputs)
+    assert rank is not None, "Input rank must be known at export time."
+    assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, (
+        f"block_shape length ({len(block_sizes)}) must match input rank ({rank})."
+    )
+    assert all(isinstance(b, int) and b > 0 for b in block_sizes), (
+        "All entries in block_shape must be positive integers."
+    )
🧹 Nitpick comments (2)
modelopt/torch/quantization/export_onnx.py (1)

512-518: Pre-validate q/k block_shapes vs tensor ranks to fail fast.

Catch mismatches early instead of deep inside TRT ops.

Apply:

-        query_scaled = export_fp8(
+        # Sanity-check block shapes
+        for name, t, bs in (("q", query_scaled, q_block_shape), ("k", key_transposed_scaled, k_block_shape)):
+            if bs is not None:
+                r = sym_help._get_tensor_rank(t)
+                assert r is not None and len(bs) == r, f"{name}_block_shape must match rank ({r})."
+        query_scaled = export_fp8(
             g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
         )
@@
-        key_transposed_scaled = export_fp8(
+        key_transposed_scaled = export_fp8(
             g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
         )
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

962-965: Guarantee cleanup of _original_input_shape via try/finally; fix typo.

Ensure attribute is deleted even on exceptions; also fix “isnt” -> “isn't”.

Apply:

-            # Reshape is required if the logic isnt handled in the simulation kernel
-            self._setup_for_blockquant(inputs)
-            setattr(self, "_original_input_shape", inputs.shape)
-            inputs = self._process_for_blockquant(inputs)
+            # Reshape is required if the logic isn't handled in the simulation kernel
+            cleanup_original_shape = False
+            try:
+                self._setup_for_blockquant(inputs)
+                setattr(self, "_original_input_shape", inputs.shape)
+                cleanup_original_shape = True
+                inputs = self._process_for_blockquant(inputs)
+            except Exception:
+                # Make sure we don't leak transient attributes on failure
+                if cleanup_original_shape and hasattr(self, "_original_input_shape"):
+                    delattr(self, "_original_input_shape")
+                raise
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2c6e0f and 0af26b2.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (405-407)
  • trt_high_precision_dtype (410-411)
  • block_sizes (287-289)
  • block_sizes (292-294)
  • amax (231-236)
  • amax (239-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/export_onnx.py (2)

545-549: Softmax stays non-block — good call.

Keeping softmax output per-tensor avoids mismatched block semantics across BMM2.


558-559: Block‑quantizing V path looks correct.

The cast order aligns with the earlier Q/K path.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

701-714: Pass block_sizes to scaled_e4m3 by keyword to avoid positional drift.

rg found only call sites (tests and histogram.py); I couldn't locate the scaled_e4m3 definition—confirm it accepts a block_sizes kwarg. If it does, apply:

         outputs = scaled_e4m3(
             inputs,
             amax,
             self._get_bias(inputs),
             E,
             M,
             self._trt_high_precision_dtype,
             self._pass_through_bwd,
-            block_sizes_list,
+            block_sizes=block_sizes_list,
         )

Signed-off-by: Jingyu Xin <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

649-671: Don’t force “block mode” when block_sizes has only metadata; return None.

Current logic returns [1, ...] even if block_sizes has no per‑dim sizes, unintentionally enabling block handling downstream. Return None when there are no integer keys with v not in {None, 1}. Add typing while here.

-    def _get_block_sizes_list(self, shape):
+    def _get_block_sizes_list(self, shape) -> list[int] | None:
@@
-        if self.block_sizes is None:
-            return None
-
-        block_sizes_list = []
+        if self.block_sizes is None:
+            return None
+        # If there are no per-dimension entries with a real block size, treat as no block quant.
+        has_dim_sizes = any(
+            isinstance(k, int) and (v is not None and v != 1)
+            for k, v in self.block_sizes.items()
+        )
+        if not has_dim_sizes:
+            return None
+
+        block_sizes_list: list[int] = []
         for dim in range(len(shape)):
             # Check both positive and negative dimension indices
             dim_negative = dim - len(shape)
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list

1006-1007: Guarantee deletion of _original_input_shape with try/finally (prev feedback).

Move deletion into a finally block guarded by the local flag so it always runs.

-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
-        return outputs
+        try:
+            return outputs
+        finally:
+            if cleanup_original_input_shape and hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

962-966: Make _original_input_shape cleanup exception‑safe: set flag when set, clean in finally.

Set a local flag when creating the attribute so the finally block can reliably clean it even if an exception occurs later.

         if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic is not handled in the simulation kernel
             # Only MX format and NVFP4 reshape are currently supported by the kernel.
             self._setup_for_blockquant(inputs)
-            setattr(self, "_original_input_shape", inputs.shape)
+            setattr(self, "_original_input_shape", inputs.shape)
+            cleanup_original_input_shape = True
             inputs = self._process_for_blockquant(inputs)

Add the flag near the top of forward (before this block):

-        # Rotating the input
+        cleanup_original_input_shape = False
+        # Rotating the input
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0af26b2 and 35f3da2.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

701-704: LGTM: using pre‑reshape shape to derive per‑dim blocks.

Using _original_input_shape avoids mismapping after reshape/flatten.


705-714: Pass block_sizes as a keyword; confirm scaled_e4m3 signature and update callers

Change this call to pass block_sizes by name to avoid positional-argument drift; before merging, confirm the scaled_e4m3 definition accepts a named block_sizes parameter (or update all callers if signature changed). Location: modelopt/torch/quantization/nn/modules/tensor_quantizer.py (around lines 705–714).

             outputs = scaled_e4m3(
                 inputs,
                 amax,
                 self._get_bias(inputs),
                 E,
                 M,
                 self._trt_high_precision_dtype,
                 self._pass_through_bwd,
-                block_sizes_list,
+                block_sizes=block_sizes_list,
             )

Quick verification commands to run locally:

  • rg -nP 'def\s+scaled_e4m3\s*\(' -C2
  • rg -nP '\bscaled_e4m3\s*\(' -C2
  • rg -nP '\bscaled_e4m3\s*\([^)]*block_sizes\s*=' -n


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

674-687: scaled_e4m3 signature change breaks existing callers—use keyword arguments.

You inserted block_sizes_list as the 3rd positional argument to scaled_e4m3, but other call sites (e.g., modelopt/torch/quantization/calib/histogram.py:311 and test files) still pass bias, E, M positionally. This will cause argument-position mismatches and break tests and ONNX export.

Call scaled_e4m3 with keyword arguments to avoid positional-shift bugs.

             # Convert block_sizes dict to list format
             # Use original input shape if available (before reshaping), otherwise use current shape
             shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
             block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
-            outputs = scaled_e4m3(
-                inputs,
-                amax,
-                self._get_bias(inputs),
-                E,
-                M,
-                self._trt_high_precision_dtype,
-                self._pass_through_bwd,
-                block_sizes_list,
-            )
+            outputs = scaled_e4m3(
+                inputs=inputs,
+                amax=amax,
+                bias=self._get_bias(inputs),
+                E=E,
+                M=M,
+                trt_high_precision_dtype=self._trt_high_precision_dtype,
+                pass_through_bwd=self._pass_through_bwd,
+                block_sizes=block_sizes_list,
+            )

Based on learnings

modelopt/torch/quantization/plugins/diffusers.py (1)

253-271: Normalize optional block-shapes to Python int-lists for ONNX attrs.

When parse_args passes 'v', convert Values to lists; keep None when absent.

     def symbolic(
         g: "GraphContext",
         query: "torch._C.Value",
@@
-    ):
-        """Symbolic method."""
+    ):
+        """Symbolic method."""
+        # Normalize optional block-shapes to Python int lists for ONNX attributes
+        if q_block_shape is not None:
+            q_block_shape = symbolic_helper._maybe_get_const(q_block_shape, "is")
+        if k_block_shape is not None:
+            k_block_shape = symbolic_helper._maybe_get_const(k_block_shape, "is")
+        if v_block_shape is not None:
+            v_block_shape = symbolic_helper._maybe_get_const(v_block_shape, "is")
         return export_fp8_mha(
             g,
             query,
             key,
             value,
♻️ Duplicate comments (5)
modelopt/torch/quantization/export_onnx.py (2)

299-325: Remove brittle default block_sizes=[1,1,128,1] and validate rank.

A hardcoded default of [1, 1, 128, 1] is unsafe—it assumes a specific tensor rank and shape. Require callers to pass block_sizes explicitly and validate it against the input tensor rank.

 def _fp8_block_dequantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     scales: torch.Value,
     trt_high_precision_dtype: str,
     otype: str | None = None,
-    block_sizes: list = [1, 1, 128, 1],
+    block_sizes: list | None = None,
 ):
     """Helper Function for Dequantization."""
+    assert block_sizes is not None, "block_sizes must be provided for block dequantization."
     output_shape = sym_help._get_tensor_sizes(inputs)
+    assert len(block_sizes) == len(output_shape), (
+        f"block_sizes length ({len(block_sizes)}) must match input rank ({len(output_shape)})."
+    )
     assert trt_high_precision_dtype in (otype, "Float"), (
         "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
     )
     out = g.op(
         "trt::TRT_BlockDequantize",
         inputs,
         scales,
         block_shape_i=block_sizes,
     ).setType(
         inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
     )
 
     # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
     # custom ops, so cast the output if needed.
     if trt_high_precision_dtype != otype:
         out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
     return out

Based on learnings


328-349: Handle tensor/Value amax correctly during symbolic export.

float(amax) on line 336 will fail if amax is a graph Value or 0-dim tensor. Use _get_const to extract Python floats when possible, or handle symbolic amax.

 def export_fp8(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     amax: float | None,
     trt_high_precision_dtype: str | None,
     block_sizes: list | None,
 ):
     """Export quantized model to FP8 ONNX."""
-    scale = 1.0 if amax is None else 448.0 / float(amax)
+    # Accept Python floats or Tensor constants; None indicates dynamic/block path not using amax.
+    if amax is None:
+        scale = 1.0
+    else:
+        amax_const = sym_help._get_const(amax, "f", "amax") if hasattr(amax, "node") else amax
+        # If still not constant, fall back to 1.0 to avoid build-time failures.
+        scale = 1.0 if amax_const is None else 448.0 / float(amax_const)
     otype = inputs.type().scalarType()
     if trt_high_precision_dtype is None:
         trt_high_precision_dtype = otype
     if not block_sizes:
         q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
         return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
     else:
         q_tensor, scales_output = _fp8_block_quantize(
             g, inputs, trt_high_precision_dtype, block_sizes
         )
         return _fp8_block_dequantize(
             g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
         )

Based on learnings
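
For reference, a minimal numeric sketch of the per-tensor scale arithmetic in export_fp8 (the block path skips amax entirely). It assumes QuantizeLinear-style semantics (divide by the scale on quantize, multiply on dequantize) and uses 448 as E4M3's maximum magnitude; E4M3 rounding itself is omitted.

amax = 3.5                       # calibrated absolute maximum of the tensor
scale = 448.0 / amax             # maps [-amax, amax] onto the E4M3 range
qdq_scale = 1.0 / scale          # value handed to the Q/DQ ops (= amax / 448)
x = 1.75
scaled = x / qdq_scale           # ~224.0, what the FP8 cast would see
recovered = scaled * qdq_scale   # dequantize recovers ~1.75 (up to E4M3 rounding)
print(qdq_scale, scaled, recovered)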

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

622-643: Guard against metadata-only block_sizes dicts returning [1,1,1,1].

If block_sizes contains only non-integer keys like {"type": "dynamic", "scale_bits": (4,3)}, you return [1]*rank, which signals "block mode" downstream even though no dimension has an actual block size. This can incorrectly trigger the FP8 block path.

Return None when there are no integer dimension keys with a block size > 1.

 def _get_block_sizes_list(self, shape):
     """Convert block_sizes dict to list format based on tensor shape.
     
     Args:
         shape: The tensor shape to use for conversion (can be tuple or torch.Size)
     
     Returns:
         List of block sizes for each dimension, or None if block_sizes is None
     
     Example:
         block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
     """
     if self.block_sizes is None:
         return None
+    
+    # Check if there are any actual per-dimension block sizes (not just metadata)
+    has_dim_sizes = any(
+        isinstance(k, int) and v not in (None, 1) for k, v in self.block_sizes.items()
+    )
+    if not has_dim_sizes:
+        return None
     
     block_sizes_list = []
     for dim in range(len(shape)):
         # Check both positive and negative dimension indices
         dim_negative = dim - len(shape)
         block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
         block_sizes_list.append(block_size if block_size is not None else 1)
     return block_sizes_list

Based on learnings
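
As a sanity check on the guarded semantics, here is a standalone sketch (a hypothetical helper, not the class method) that mirrors the suggested conversion and the docstring example:

def block_sizes_to_list(block_sizes, shape):
    if block_sizes is None:
        return None
    # Dicts carrying only metadata ("type", "scale_bits", ...) mean no block mode.
    if not any(isinstance(k, int) and v not in (None, 1) for k, v in block_sizes.items()):
        return None
    rank = len(shape)
    return [block_sizes.get(d) or block_sizes.get(d - rank) or 1 for d in range(rank)]

print(block_sizes_to_list({-2: 32}, (2, 24, 4608, 128)))             # [1, 1, 32, 1]
print(block_sizes_to_list({"type": "dynamic"}, (2, 24, 4608, 128)))  # None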


975-976: Cleanup of _original_input_shape should be in a try-finally block.

The deletion of _original_input_shape at the end of forward is not guaranteed if an exception occurs during the forward pass. This can leave stale state that corrupts subsequent calls.

Wrap the forward logic that depends on _original_input_shape in a try-finally block to guarantee cleanup.

         if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic is not handled in the simulation kernel
             # Only MX format and NVFP4 reshape are currently supported by the kernel.
             self._setup_for_blockquant(inputs)
             setattr(self, "_original_input_shape", inputs.shape)
             inputs = self._process_for_blockquant(inputs)
 
-        outputs = inputs
-
-        block_size = None
-        if self._if_calib and not self._dynamic:
-            if self._calibrator is None:
-                raise RuntimeError("Calibrator was not created.")
-            # Shape is only known when it sees the first tensor
-            if self.block_sizes is not None and self.block_sizes.get("type", None) == "dynamic":
-                block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
-                    inputs.dim() - 1, None
-                )
-                assert block_size is not None, "block size for dynamic quantization not found."
-
-            self.collect(inputs)
-
-        if self._if_quant:
-            # Check if the input tensor is contiguous
-            # Non-contiguous tensors will generate incorrect FP4 quantization results
-            if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
-                inputs.data = inputs.data.contiguous()
-            if self.fake_quant:
-                outputs = self._fake_quantize(inputs)
-            elif not self._dequantize:
-                outputs = self._real_quantize(inputs)
-            else:
-                raise ValueError(
-                    "self._dequantize is True and self.fake_quant is False. "
-                    "This case should have been handled."
-                )
-
-        if (
-            self.block_sizes is not None
-            and self.block_sizes.get("type", None) != "dynamic"
-            and self._fake_quant
-        ):
-            outputs = self._reset_to_original_shape(outputs)
-
-        if hasattr(self, "_original_input_shape"):
-            delattr(self, "_original_input_shape")
+        try:
+            outputs = inputs
+
+            block_size = None
+            if self._if_calib and not self._dynamic:
+                if self._calibrator is None:
+                    raise RuntimeError("Calibrator was not created.")
+                # Shape is only known when it sees the first tensor
+                if self.block_sizes is not None and self.block_sizes.get("type", None) == "dynamic":
+                    block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
+                        inputs.dim() - 1, None
+                    )
+                    assert block_size is not None, "block size for dynamic quantization not found."
+
+                self.collect(inputs)
+
+            if self._if_quant:
+                # Check if the input tensor is contiguous
+                # Non-contiguous tensors will generate incorrect FP4 quantization results
+                if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
+                    inputs.data = inputs.data.contiguous()
+                if self.fake_quant:
+                    outputs = self._fake_quantize(inputs)
+                elif not self._dequantize:
+                    outputs = self._real_quantize(inputs)
+                else:
+                    raise ValueError(
+                        "self._dequantize is True and self.fake_quant is False. "
+                        "This case should have been handled."
+                    )
+
+            if (
+                self.block_sizes is not None
+                and self.block_sizes.get("type", None) != "dynamic"
+                and self._fake_quant
+            ):
+                outputs = self._reset_to_original_shape(outputs)
+        finally:
+            if hasattr(self, "_original_input_shape"):
+                delattr(self, "_original_input_shape")
+                
         return outputs

Based on learnings
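
A minimal sketch of the context-manager alternative (the helper name and the commented usage are hypothetical); it guarantees the same cleanup as the try-finally above:

from contextlib import contextmanager

@contextmanager
def temporary_attr(obj, name, value):
    """Set an attribute for the duration of a block and always remove it."""
    setattr(obj, name, value)
    try:
        yield
    finally:
        if hasattr(obj, name):
            delattr(obj, name)

# Illustrative usage inside forward:
# with temporary_attr(self, "_original_input_shape", inputs.shape):
#     inputs = self._process_for_blockquant(inputs)
#     outputs = self._fake_quantize(inputs)
#     outputs = self._reset_to_original_shape(outputs)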

modelopt/torch/quantization/plugins/diffusers.py (1)

231-235: Fix ONNX parse_args: allow optional scales and block-shapes (prevents export crash).

  • q/k/v scales can be None for dynamic quantization; using 't' rejects None.
  • q/k/v block-shapes are optional; using 'is' rejects None.

Update parse_args to accept optionals with 'v' for these 6 entries.

-    @symbolic_helper.parse_args(
-        "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
-    )
+    @symbolic_helper.parse_args(
+        "v", "v", "v", "v", "f", "b", "v", "v", "v", "v", "s", "b", "v", "v", "v"
+    )

Also normalize the 3 block-shape values to Python lists inside symbolic before forwarding (see next comment). Based on learnings.

Also applies to: 243-251
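
For context, a minimal sketch of the "v" plus _maybe_get_const pattern (the function and argument names are hypothetical; it assumes torch.onnx.symbolic_helper, where "v" forwards the raw graph value so optional inputs survive, and the exact handling of None can vary between export paths):

from torch.onnx import symbolic_helper

@symbolic_helper.parse_args("v", "v", "v")  # keep tensor, scale, block_shape unparsed
def my_symbolic(g, x, scale, block_shape):
    if block_shape is not None:
        # Extract a Python list of ints only when a constant is actually present.
        block_shape = symbolic_helper._maybe_get_const(block_shape, "is")
    # ... build the graph with g.op(...), branching on whether block_shape is set ...
    return x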

🧹 Nitpick comments (4)
modelopt/torch/quantization/plugins/diffusers.py (4)

120-135: Clarify assertion message for mixed dynamic/static Q/K/V quantizers.

Message is cryptic. Make intent explicit.

-        ), "QKV QDQS must be in the same type"
+        ), "q/k/v quantizers must all be either dynamic or static"

243-251: Adjust symbolic type hints to reflect actual tensor/value types.

Scales are Values/Tensors, not floats; block-shapes are list[int] | None.

-        scale: torch._C.Value | None = None,
-        q_quantized_scale: float | None = 1.0,
-        k_quantized_scale: float | None = 1.0,
-        v_quantized_scale: float | None = 1.0,
+        scale: "torch._C.Value | None" = None,
+        q_quantized_scale: "torch._C.Value | None" = None,
+        k_quantized_scale: "torch._C.Value | None" = None,
+        v_quantized_scale: "torch._C.Value | None" = None,
         high_precision_flag: str = "Half",
         disable_fp8_mha: bool = True,
-        q_block_shape: list | None = None,
-        k_block_shape: list | None = None,
-        v_block_shape: list | None = None,
+        q_block_shape: "list[int] | None" = None,
+        k_block_shape: "list[int] | None" = None,
+        v_block_shape: "list[int] | None" = None,

136-139: Block-shape None path is now safe once parse_args is fixed; optional naming nit.

With the parse_args/symbolic normalization above, passing None here is OK. Consider renaming variables to *_block_shape for consistency with FP8SDPA.forward.

-    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
-    k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)  # type: ignore[union-attr]
-    v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)  # type: ignore[union-attr]
+    q_block_shape = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+    k_block_shape = self.k_bmm_quantizer._get_block_sizes_list(key.shape)  # type: ignore[union-attr]
+    v_block_shape = self.v_bmm_quantizer._get_block_sizes_list(value.shape)  # type: ignore[union-attr]

Remember to update the call site accordingly (Lines 155-157).


155-157: If you rename to *_block_shape, update the call site accordingly.

-            q_block_sizes,
-            k_block_sizes,
-            v_block_sizes,
+            q_block_shape,
+            k_block_shape,
+            v_block_shape,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25be640 and f80b847.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/export_onnx.py (6 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
  • modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • block_sizes (293-295)
  • block_sizes (298-300)
  • amax (236-241)
  • amax (244-255)
  • trt_high_precision_dtype (411-413)
  • trt_high_precision_dtype (416-417)
modelopt/torch/quantization/export_onnx.py (1)
  • export_fp8 (328-349)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
  • trt_high_precision_dtype (411-413)
  • trt_high_precision_dtype (416-417)
  • block_sizes (293-295)
  • block_sizes (298-300)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • _get_amax (496-505)
  • _get_block_sizes_list (622-643)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/quantization/tensor_quant.py (3)

409-424: LGTM—ScaledE4M3Function signature correctly extended for block_sizes.

The symbolic and forward signatures are updated consistently:

  • parse_args extended to ("v", "t", "t", "is", "i", "i", "s", "b") to accept block_sizes as int or sequence.
  • export_fp8 call correctly passes block_sizes as the 4th argument.
  • The change aligns with the broader PR goal of propagating block sizes through the FP8 export path.

437-437: LGTM—forward signature extended with block_sizes parameter.

The block_sizes=None parameter is appended to the end of the signature, maintaining backward compatibility for callers that do not pass it. This aligns with the symbolic signature update.


465-465: LGTM—backward num_args updated to 8 to match new signature.

The num_args=8 correctly accounts for the additional block_sizes parameter added to the forward signature. This ensures the backward pass handles all arguments properly.

modelopt/torch/quantization/export_onnx.py (5)

243-269: LGTM—_fp8_block_quantize correctly implements block quantization.

The helper correctly:

  • Casts inputs to the target dtype if needed.
  • Calls trt::TRT_DynamicQuantize with block_shape_i=block_sizes.
  • Returns both the quantized output and the scales (see the shape sketch after this list).
  • Sets the output type to uint8 as expected.
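
Assuming the usual one-scale-per-block convention (an assumption; the scales layout of trt::TRT_DynamicQuantize is not documented here), the scales shape would be the element-wise ceiling division of the input shape by the block shape:

import math

def expected_scales_shape(input_shape, block_shape):
    assert len(input_shape) == len(block_shape)
    return [math.ceil(s / b) for s, b in zip(input_shape, block_shape)]

# e.g. a [2, 24, 4608, 128] query with block_shape [1, 1, 32, 1]
print(expected_scales_shape([2, 24, 4608, 128], [1, 1, 32, 1]))  # [2, 24, 144, 128]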

444-452: LGTM—export_fp8_mha signature correctly extended for block shapes.

The signature updates add:

  • q_quantized_scale, k_quantized_scale, v_quantized_scale as float | None.
  • q_block_shape, k_block_shape, v_block_shape as list | None.

This aligns with the broader PR goal of propagating per-tensor block shapes through the FP8-MHA export path.


529-534: LGTM—export_fp8 called with q_block_shape for query quantization.

The call correctly passes q_block_shape to export_fp8, enabling block quantization for the query tensor in the FP8-MHA path.


562-563: LGTM—softmax output correctly uses None for block_shape.

The comment "We cannot do block quant for the softmax's output" is accurate, and passing None as the block_shape ensures the softmax output is not block-quantized.


573-573: LGTM—export_fp8 called with v_block_shape for value quantization.

The call correctly passes v_block_shape to export_fp8, enabling block quantization for the value tensor in the FP8-MHA path.

modelopt/torch/quantization/plugins/diffusers.py (1)

268-270: Forwarding of block-shapes to export looks correct.

No issues forwarding q/k/v block-shapes once normalized.

Comment on lines +932 to 936
# Reshape is required if the logic is not handled in the simulation kernel
# Only MX format and NVFP4 reshape are currently supported by the kernel.
self._setup_for_blockquant(inputs)
setattr(self, "_original_input_shape", inputs.shape)
inputs = self._process_for_blockquant(inputs)

⚠️ Potential issue | 🟠 Major

Storing _original_input_shape as an attribute creates state leakage risks.

Setting _original_input_shape on self in the forward pass can cause issues if an exception is raised before cleanup (line 976) or if multiple forward calls interleave (e.g., during model export or distributed training). The attribute can linger and pollute subsequent forward passes.

Consider using a local variable or a context manager to ensure cleanup even on exceptions.

         if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic is not handled in the simulation kernel
             # Only MX format and NVFP4 reshape are currently supported by the kernel.
             self._setup_for_blockquant(inputs)
-            setattr(self, "_original_input_shape", inputs.shape)
+            original_input_shape = inputs.shape
+            self._original_input_shape = original_input_shape
             inputs = self._process_for_blockquant(inputs)

Then wrap the forward logic in try-finally as noted in the next comment.

Based on learnings

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 932
to 936, avoid assigning self._original_input_shape inside forward since it leaks
state across calls and during exceptions; instead store the original input shape
in a local variable (e.g., original_input_shape = inputs.shape) or use a
short-lived context manager, and wrap the forward transformation/processing
logic in a try-finally block so any necessary cleanup (previously at line 976)
always runs even if an exception occurs; remove the attribute assignment and
update subsequent references to use the local variable or the context-managed
value.

Comment on lines 231 to 235
     @staticmethod
-    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
+    @symbolic_helper.parse_args(
+        "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
+    )
     def symbolic(

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Repo-wide check: ensure no other symbolics still use 't' for optional scales.

Run to verify and avoid future breakages.


🏁 Script executed:

#!/bin/bash
# Find symbolics that parse optional scales with 't' (should be 'v')
rg -nP '@symbolic_helper\.parse_args\([^)]*\)' -C2 | rg -n 'parse_args\((.*)\)' -o | sed -E 's/.*parse_args\((.*)\).*/\1/' | awk 'BEGIN{RS=","} /"t"/{print FILENAME ":" NR ": " $0}' FILENAME=/dev/stdin
# Narrow search for q/k/v scale params in this file
rg -n 'parse_args\(' modelopt/torch/quantization/plugins/diffusers.py -n -C1

Length of output: 200


🏁 Script executed:

#!/bin/bash
# List all @symbolic_helper.parse_args calls containing "t" across Python files
rg -nP '@symbolic_helper\.parse_args\([^)]*\"t\"[^)]*\)' -g '*.py' -C1

Length of output: 779


🏁 Script executed:

#!/bin/bash
# Show context around parse_args usages in tensor_quant.py
sed -n '320,335p' modelopt/torch/quantization/tensor_quant.py
sed -n '405,420p' modelopt/torch/quantization/tensor_quant.py
sed -n '500,515p' modelopt/torch/quantization/tensor_quant.py

Length of output: 1170


Replace 't' with 'v' for optional scale parameters in all symbolic definitions

  • modelopt/torch/quantization/tensor_quant.py: FakeTensorQuantFunction.symbolic (parse_args at line 326)
  • modelopt/torch/quantization/tensor_quant.py: ScaledE4M3Function.symbolic (parse_args at line 409)
  • modelopt/torch/quantization/tensor_quant.py: DynamicBlockQuantizationFunction.symbolic (parse_args at line 506)
