FP8 Block quantize onnx export support #324
base: main
Conversation
Walkthrough
Adds optional FP8 blockwise quantization end-to-end: new FP8_SAGE_DEFAULT_CONFIG, runtime-driven quantize_mha flag, propagation of per-tensor vs dynamic block shapes through diffusers attention and ONNX symbolic/export paths, TensorQuantizer and ScaledE4M3 support for block sizes, and ONNX helpers for blockwise FP8 quantize/dequantize.
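For reference, a minimal sketch of what a per-block dynamic FP8 quantizer entry of this kind looks like; only the *[qkv]_bmm_quantizer pattern and its block_sizes value come from the reviewed diff, and the surrounding dict structure is illustrative:

# Hypothetical excerpt; key names outside the quantizer entry are assumptions.
FP8_SAGE_DEFAULT_CONFIG = {
    "quant_cfg": {
        # Dynamic FP8 (E4M3) with 32-element blocks along dimension -2
        # for the Q/K/V BMM input quantizers.
        "*[qkv]_bmm_quantizer": {
            "type": "dynamic",
            "num_bits": (4, 3),
            "block_sizes": {-2: 32},
        },
    },
}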
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User
participant QC as QuantConfig
participant Attn as _QuantAttention/FP8SDPA
participant Sym as FP8SDPA.symbolic
participant Export as export_fp8_mha
participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant
User->>QC: load FP8_SAGE_DEFAULT_CONFIG
User->>Attn: forward(query, key, value, ...)
Attn->>Attn: detect dynamic vs non-dynamic Q/K/V
Attn->>Attn: compute q/k/v block_shapes via _get_block_sizes_list()
Attn->>Sym: symbolic(..., q_block_shape, k_block_shape, v_block_shape)
Sym->>Export: export_fp8_mha(..., q_block_shape, k_block_shape, v_block_shape)
alt block shapes present
Export->>BlockHelpers: _fp8_block_quantize(Q/K/V, block_shape)
BlockHelpers-->>Export: quantized uint8 + scales
Export->>BlockHelpers: _fp8_block_dequantize(..., block_shape)
BlockHelpers-->>Export: dequantized tensors
else no block shapes
Export->>Export: per-tensor FP8 quantize/dequantize
end
Export-->>User: ONNX graph with FP8 (blockwise or per-tensor)
sequenceDiagram
autonumber
participant TensorQ as TensorQuantizer
participant SE4 as ScaledE4M3Function
participant Export as export_fp8
participant BlockHelpers as _fp8_block_quant/_fp8_block_dequant
TensorQ->>TensorQ: _get_block_sizes_list(_original_input_shape)
TensorQ->>SE4: forward(x, scale, amax, block_sizes_list, ...)
alt ONNX export path
SE4->>Export: export_fp8(..., amax=None|float, block_sizes)
opt block_sizes provided
Export->>BlockHelpers: _fp8_block_quantize/_fp8_block_dequantize
end
else eager fake-quant
SE4->>SE4: apply fake-quant with block sizes
end
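The runtime flow in these diagrams corresponds roughly to the usual quantize-then-export workflow. A hedged sketch follows (the config lives in the example's config.py per the changed files; model loading, dummy inputs, and export arguments are placeholders, and the actual entry point in examples/diffusers/quantization/quantize.py may differ):

import torch
import modelopt.torch.quantization as mtq
from config import FP8_SAGE_DEFAULT_CONFIG  # examples/diffusers/quantization/config.py (assumed import path)

model = load_attention_model()  # placeholder for the user's diffusion backbone
# Dynamic block quantization needs no calibration loop.
model = mtq.quantize(model, FP8_SAGE_DEFAULT_CONFIG, forward_loop=None)

dummy_inputs = make_dummy_inputs()  # placeholder
torch.onnx.export(model, dummy_inputs, "model_fp8_block.onnx", opset_version=17)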
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: Jingyu Xin <[email protected]>
Force-pushed from e4d1775 to 071f167
Signed-off-by: Jingyu Xin <[email protected]>
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #324 +/- ##
==========================================
- Coverage 73.79% 73.68% -0.11%
==========================================
Files 171 171
Lines 17583 17616 +33
==========================================
+ Hits 12975 12981 +6
- Misses 4608 4635 +27
☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (11)
examples/diffusers/quantization/config.py (1)
39-39
: Fix spacing inconsistency in configuration.
The configuration has inconsistent spacing after commas. Line 39 has a missing space after the comma between (4, 3) and "block_sizes".
- "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
+ "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
modelopt/torch/quantization/export_onnx.py (2)
237-263
: Consider adding validation for block_sizes parameter.
The new _fp8_block_quantize function should validate the structure and values of the block_sizes parameter to prevent runtime errors during ONNX export.
Add validation at the beginning of the function:
  def _fp8_block_quantize(
      g: torch.onnx._internal.jit_utils.GraphContext,
      inputs: torch.Value,
      trt_high_precision_dtype: str,
      block_sizes: list,
  ):
      """Helper Function for Quantization."""
+     if not isinstance(block_sizes, list) or not block_sizes:
+         raise ValueError(f"block_sizes must be a non-empty list, got {block_sizes}")
+     if not all(isinstance(b, int) and b > 0 for b in block_sizes):
+         raise ValueError(f"All block sizes must be positive integers, got {block_sizes}")
+
      output_shape = sym_help._get_tensor_sizes(inputs)
534-535
: Fix typo in comment.
- # We cannot do block quant for the softmax’s output
+ # We cannot do block quant for the softmax's output
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
653-653
: Remove trailing whitespace from blank lines.
These blank lines contain unnecessary whitespace which violates Python style guidelines.
-
+
    Args:
        shape: The tensor shape to use for conversion (can be tuple or torch.Size)
-
+
    Returns:
        List of block sizes for each dimension, or None if block_sizes is None
-
+
    Example:
Also applies to: 656-656, 659-659
961-962
: Consider thread-safety for the _original_input_shape attribute.
Setting and deleting _original_input_shape as a temporary attribute could cause issues in multi-threaded scenarios where the same quantizer is used by multiple threads simultaneously.
Consider using a context manager or local variable approach instead:
- setattr(self, "_original_input_shape", inputs.shape)
- inputs = self._process_for_blockquant(inputs)
+ original_shape = inputs.shape
+ inputs = self._process_for_blockquant(inputs)
+ # Pass original_shape to methods that need it
Alternatively, consider storing it in a thread-local storage if multi-threading support is required.
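A rough sketch of the context-manager variant this comment suggests (names mirror the snippet above; this is illustrative, not the reviewed implementation):

from contextlib import contextmanager

@contextmanager
def _original_shape_scope(quantizer, shape):
    # Record the pre-reshape input shape and guarantee cleanup even on exceptions.
    quantizer._original_input_shape = shape
    try:
        yield
    finally:
        if hasattr(quantizer, "_original_input_shape"):
            del quantizer._original_input_shape

# Usage inside forward(), assuming the same attribute name as above:
#   with _original_shape_scope(self, inputs.shape):
#       inputs = self._process_for_blockquant(inputs)
#       ...  # quantize / calibrate as before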
modelopt/torch/quantization/plugins/diffusers.py (6)
117-124
: Fix assertion logic for mixed dynamic/non-dynamic quantizers
The current implementation requires all QKV quantizers to be either dynamic or non-dynamic together. However, the logic flow is problematic - if they're all non-dynamic, scales are computed, but if any is dynamic, it asserts all must be dynamic. This creates a rigid constraint that may not be necessary for all use cases.
Consider refactoring to handle mixed cases more gracefully:
- if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic:
-     q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-     k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-     v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
- else:
-     assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
-     q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
+ # Compute scales for non-dynamic quantizers, set None for dynamic ones
+ q_quantized_scale = None if self.q_bmm_quantizer._dynamic else self.q_bmm_quantizer._get_amax(query)
+ k_quantized_scale = None if self.k_bmm_quantizer._dynamic else self.k_bmm_quantizer._get_amax(key)
+ v_quantized_scale = None if self.v_bmm_quantizer._dynamic else self.v_bmm_quantizer._get_amax(value)
+
+ # Optionally validate consistency if needed
+ dynamic_states = [self.q_bmm_quantizer._dynamic, self.k_bmm_quantizer._dynamic, self.v_bmm_quantizer._dynamic]
+ if len(set(dynamic_states)) > 1:
+     # Log warning or handle mixed dynamic states if necessary
+     pass
122-122
: Fix line length violation
Line 122 exceeds the 120 character limit (149 characters).
- assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
+ assert (self.q_bmm_quantizer._dynamic and
+     self.k_bmm_quantizer._dynamic and
+     self.v_bmm_quantizer._dynamic), "QKV QDQS must be in the same type"
144-146
: Remove trailing whitespace
Line 145 has trailing whitespace after the comma.
  q_block_sizes,
- k_block_sizes,
+ k_block_sizes,
  v_block_sizes,
231-233
: Inconsistent default values for scale parameters
The scale parameters have inconsistent default values in the symbolic method signature (float | None = 1.0), which doesn't match the forward method where they default to None.
- q_quantized_scale: float | None = 1.0,
- k_quantized_scale: float | None = 1.0,
- v_quantized_scale: float | None = 1.0,
+ q_quantized_scale: float | None = None,
+ k_quantized_scale: float | None = None,
+ v_quantized_scale: float | None = None,
200-202
: Consider using TypeAlias for block shape type consistency
The block shape parameters use list | None type annotations repeatedly. Consider defining a type alias for better maintainability and consistency.
Add at the top of the file after imports:
from typing import TypeAlias

BlockShape: TypeAlias = list[int] | None
Then update the signatures:
- q_block_shape: list | None = None,
- k_block_shape: list | None = None,
- v_block_shape: list | None = None,
+ q_block_shape: BlockShape = None,
+ k_block_shape: BlockShape = None,
+ v_block_shape: BlockShape = None,
Also applies to: 236-238
126-128
: Add validation for block sizes consistency
The code retrieves block sizes from quantizers but doesn't validate that they're compatible with the actual tensor shapes or the quantization configuration.
Consider adding validation to ensure block sizes are appropriate:
  # Get block sizes lists for each quantizer if needed
  q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)
  k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)
  v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)
+
+ # Validate block sizes if dynamic quantization is enabled
+ if self.q_bmm_quantizer._dynamic and q_block_sizes:
+     for dim, block_size in enumerate(q_block_sizes):
+         if block_size > 1 and query.shape[dim] % block_size != 0:
+             raise ValueError(f"Query dimension {dim} (size {query.shape[dim]}) is not divisible by block size {block_size}")
+ # Similar validation for k and v can be added
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/diffusers/quantization/config.py (1 hunks)
examples/diffusers/quantization/quantize.py (1 hunks)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
modelopt/torch/quantization/tensor_quant.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax (540-549)
_get_block_sizes_list (651-672)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype (407-409)
trt_high_precision_dtype (412-413)
block_sizes (289-291)
block_sizes (294-296)
amax (233-238)
amax (241-252)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (7)
amax (233-238)
amax (241-252)
block_sizes (289-291)
block_sizes (294-296)
trt_high_precision_dtype (407-409)
trt_high_precision_dtype (412-413)
forward (902-1004)
modelopt/torch/quantization/export_onnx.py (1)
export_fp8 (321-342)
🪛 GitHub Actions: Code Quality
modelopt/torch/quantization/plugins/diffusers.py
[error] 122-122: E501 Line too long (149 > 120).
[error] 134-136: mypy: Item 'str' of 'str | float | None' has no attribute 'shape' [union-attr] (lines 134-136).
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
[warning] 653-653: W293 Blank line contains whitespace.
[warning] 656-656: W293 Blank line contains whitespace.
[warning] 659-659: W293 Blank line contains whitespace.
[warning] 959-959: RUF003 Comment contains ambiguous RIGHT SINGLE QUOTATION MARK ’. Did you mean `` (GRAVE ACCENT)?
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
🔇 Additional comments (6)
examples/diffusers/quantization/quantize.py (1)
942-942
: LGTM! Runtime configuration for MHA quantization is properly handled.
The change from QuantizationConfig.quantize_mha to quant_config.quantize_mha correctly uses the runtime configuration for MHA quantization. This aligns with the overall pattern of making MHA quantization configurable.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
651-673
: LGTM! Well-designed helper method for block size conversion.
The _get_block_sizes_list method is well-implemented with proper handling of both positive and negative dimension indices. The documentation with examples clearly explains its purpose.
modelopt/torch/quantization/export_onnx.py (1)
298-298
: Incorrect — default block_sizes is not exercised by current callers
Call sites found in modelopt/torch/quantization/export_onnx.py pass block_sizes into _fp8_block_dequantize, so the hardcoded [1,1,128,1] default isn't used; no change required.
Likely an incorrect or invalid review comment.
modelopt/torch/quantization/tensor_quant.py (1)
415-415
: No action required — parse_args('i' → 'is') is backward-compatible.
'is' accepts both integer and integer-list, so existing callers that pass integers remain valid (repo tests call scaled_e4m3(..., 4, 3)).
modelopt/torch/quantization/plugins/diffusers.py (2)
87-148
: Good implementation of block quantization propagation!
The changes effectively extend the FP8 quantization path to support dynamic block quantization:
- Proper handling of dynamic vs non-dynamic quantizers
- Clean propagation of block shape parameters through the call stack
- Maintains backward compatibility with existing code
221-221
: parse_args: 'is' is correct for int-list (or None)
Matches existing usage in modelopt/torch/quantization/tensor_quant.py — the three "is" entries correctly map the three block-shape parameters to int[] | None.
Actionable comments posted: 3
♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
959-963
: Ensure _original_input_shape is always cleaned up (wrap in try/finally)
State can leak if an exception occurs between setting and deleting _original_input_shape. Wrap the forward section after setup in try/finally.
- if ( + cleanup_original_shape = False + if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): # Reshape is required if the logic isnt handled in the simulation kernel self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_shape = True inputs = self._process_for_blockquant(inputs) - outputs = inputs + try: + outputs = inputs ... - if hasattr(self, "_original_input_shape"): - delattr(self, "_original_input_shape") + finally: + if cleanup_original_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape")Also applies to: 1002-1003
🧹 Nitpick comments (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
681-699
: Comment is misleading within dynamic-only branch
The comment says "Double scale Block quantization, including dynamic and static block quantization" but this branch executes only when type == "dynamic". Tighten the comment to avoid confusion.
- # Double scale Block quantization, including dynamic and static block quantization
+ # Dynamic double-scale block quantization path
modelopt/torch/quantization/plugins/diffusers.py (2)
117-132
: Align QKV mode and error message; avoid computing scales in export path
- Message “QKV QDQS must be in the same type” is unclear. Make it explicit: “Q, K, and V quantizers must all be dynamic or all be static.”
- Skip _get_amax when exporting; it’s unused at runtime and can be None for dynamic. Guard by torch.onnx.is_in_onnx_export().
- if ( + if ( not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic ): - q_quantized_scale = self.q_bmm_quantizer._get_amax(query) - k_quantized_scale = self.k_bmm_quantizer._get_amax(key) - v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + if not torch.onnx.is_in_onnx_export(): + q_quantized_scale = self.q_bmm_quantizer._get_amax(query) + k_quantized_scale = self.k_bmm_quantizer._get_amax(key) + v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + else: + q_quantized_scale = k_quantized_scale = v_quantized_scale = None else: assert ( self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic - ), "QKV QDQS must be in the same type" + ), "Q, K, and V quantizers must all be dynamic or all be static." q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
133-137
: Using a private helper across modules; consider promoting to public API
Calling _get_block_sizes_list from another module couples to a private method. Expose it as a public helper (e.g., get_block_sizes_list or a util function) to avoid brittle dependencies.
Would you like me to extract a small utility, e.g., modelopt/torch/quantization/utils/block_sizes.py:get_block_sizes_list(shape, block_sizes)?
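A minimal sketch of such a utility, following the documented mapping ({-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]); the module path and public name here are only the reviewer's suggestion:

def get_block_sizes_list(shape, block_sizes):
    # Expand a {dim: block_size} mapping into a per-dimension list.
    # Non-integer keys (e.g. "type", "scale_bits") are never looked up, so they
    # are ignored; dimensions without an entry default to a block size of 1.
    if block_sizes is None:
        return None
    rank = len(shape)
    result = []
    for dim in range(rank):
        size = block_sizes.get(dim) or block_sizes.get(dim - rank)
        result.append(size if size is not None else 1)
    return result

# get_block_sizes_list([2, 24, 4608, 128], {-2: 32}) -> [1, 1, 32, 1]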
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/diffusers/quantization/config.py (1 hunks)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/diffusers/quantization/config.py
- modelopt/torch/quantization/export_onnx.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax (540-549)
_get_block_sizes_list (651-672)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
🔇 Additional comments (2)
modelopt/torch/quantization/plugins/diffusers.py (2)
151-155
: Default disable flag may be surprising
You pass True when _disable_fp8_mha is absent, which disables FP8 MHA by default. Confirm this is intended for evaluation builds, or flip the default to False.
Would you like a config flag gate to avoid silent disablement in production?
251-269
: Verified — export_fp8_mha signature matches callsite
Definition in modelopt/torch/quantization/export_onnx.py (def export_fp8_mha at ~line 420) includes q_block_shape, k_block_shape, v_block_shape; the call in modelopt/torch/quantization/plugins/diffusers.py passes them — no mismatch found.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
703-716
: Fix scaled_e4m3 call-site argument order
scaled_e4m3 now expects block_sizes before bias; update all callers to the new signature.
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py:707 (new usage)
- modelopt/torch/quantization/calib/histogram.py:311
- tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py:55
- tests/gpu/torch/quantization/test_tensor_quant_cuda.py:148, 158, 166, 173, 185, 187, 202
Use: scaled_e4m3(inputs, amax, block_sizes, bias, E, M, ...). If no block_sizes, pass None as the third argument and move bias to the fourth.
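For illustration, a call site updated to the order described above (values are hypothetical; pass None for block_sizes when no block quantization is wanted):

# Before the signature change:
#   out = scaled_e4m3(x, amax, bias, 4, 3)
# After, with block_sizes inserted as the third argument:
out = scaled_e4m3(x, amax, None, bias, 4, 3)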
♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
959-963
: Guarantee cleanup of _original_input_shape on exceptions.
Deletion is unconditional but not exception-safe; wrap the quantization region in try/finally. Also fix the typo "isnt" → "isn't".
- if ( + cleanup_original_input_shape = False + if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): - # Reshape is required if the logic isnt handled in the simulation kernel + # Reshape is required if the logic isn't handled in the simulation kernel self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_input_shape = True inputs = self._process_for_blockquant(inputs) - outputs = inputs + try: + outputs = inputs @@ - if ( + if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant - ): - outputs = self._reset_to_original_shape(outputs) - - if hasattr(self, "_original_input_shape"): - delattr(self, "_original_input_shape") - return outputs + ): + outputs = self._reset_to_original_shape(outputs) + return outputs + finally: + if cleanup_original_input_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape")Also applies to: 1002-1003
🧹 Nitpick comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
651-673
: Type-safety and input validation for _get_block_sizes_list.
Add explicit typing, validate keys, and guard length mismatches to avoid silently passing malformed shapes to downstream ONNX ops.
-def _get_block_sizes_list(self, shape): +from typing import Sequence + +def _get_block_sizes_list(self, shape: Sequence[int] | torch.Size) -> list[int] | None: @@ - block_sizes_list = [] - for dim in range(len(shape)): + # Only allow integer axes plus known metadata keys. + valid_meta = {"type", "scale_bits", "scale_block_sizes"} + assert all( + isinstance(k, int) or k in valid_meta for k in self.block_sizes.keys() + ), f"Invalid block_sizes keys: {list(self.block_sizes.keys())}" + + rank = len(shape) + block_sizes_list: list[int] = [] + for dim in range(rank): # Check both positive and negative dimension indices - dim_negative = dim - len(shape) + dim_negative = dim - rank block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_listmodelopt/torch/quantization/export_onnx.py (2)
238-265
: Validate block_shape rank and surface clearer errors.
Add a rank check before emitting TRT_DynamicQuantize; mis-sized block_shapes currently fall through to TRT with cryptic errors.
def _fp8_block_quantize( @@ - input_type = inputs.type().scalarType() + input_type = inputs.type().scalarType() + rank = symbolic_helper._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) @@ quantized_output, scales_output = g.op( "trt::TRT_DynamicQuantize", inputs, block_shape_i=block_sizes,
503-509
: Block-shape consistency in FP8 MHA path.
Validate that q/k/v block shapes match input ranks; also ensure the softmax path never receives a block shape.
- query_scaled = export_fp8( - g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape - ) + assert (q_block_shape is None) or ( + len(q_block_shape) == symbolic_helper._get_tensor_rank(query_scaled) + ), "q_block_shape rank mismatch." + query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape) @@ - key_transposed_scaled = export_fp8( - g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape - ) + assert (k_block_shape is None) or ( + len(k_block_shape) == symbolic_helper._get_tensor_rank(key_transposed_scaled) + ), "k_block_shape rank mismatch." + key_transposed_scaled = export_fp8(g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape) @@ - # We cannot do block quant for the softmax's output - attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) + # We cannot do block quant for the softmax's output + attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) @@ - value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape) + assert (v_block_shape is None) or ( + len(v_block_shape) == symbolic_helper._get_tensor_rank(value) + ), "v_block_shape rank mismatch." + value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)Also applies to: 535-549
modelopt/torch/quantization/plugins/diffusers.py (2)
117-132
: Q/K/V quantization mode must match: improve error and skip redundant work.
The assertion is good. Minor: clarify the message and avoid computing per-tensor amax if any quantizer is dynamic.
- ), "QKV QDQS must be in the same type"
+ ), "Q/K/V quantization modes must match: either all dynamic or all static."
133-137
: Guard block size list creation when block_sizes is None.
If a quantizer has no block_sizes, _get_block_sizes_list returns None; that's fine. Add a quick comment to make the intent explicit and future-proof.
- q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+ # Returns None for per-tensor paths; ONNX export handles that by taking the non-block path.
+ q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/diffusers/quantization/config.py (1 hunks)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/diffusers/quantization/config.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype (407-409)
trt_high_precision_dtype (412-413)
block_sizes (289-291)
block_sizes (294-296)
amax (233-238)
amax (241-252)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax (540-549)
_get_block_sizes_list (651-672)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/diffusers.py (1)
229-231
: Keep 't' for q/k/v scales — export extracts consts
export_onnx.py already extracts constant floats (uses sym_help._get_const / _maybe_get_const for scale/amax), so the current parse_args ("..., 't','t','t', ...") is fine; only change those three to 'f' if the export_fp8 const-extraction fix is removed. Location: modelopt/torch/quantization/plugins/diffusers.py (symbolic decorator around lines 229–231; same check applies to lines ~241–249).
Signed-off-by: Jingyu Xin <[email protected]>
Please share a sample command for ONNX exporting for a supported model as well in the description.
Signed-off-by: Jingyu Xin <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
649-671
: Return None for “metadata‑only” block_sizes and add typing.
Avoid forcing block mode with [1,…,1] when there are no per‑dim entries; this can route the ONNX path incorrectly.
Apply:
- def _get_block_sizes_list(self, shape): + def _get_block_sizes_list(self, shape) -> list[int] | None: """Convert block_sizes dict to list format based on tensor shape. @@ - if self.block_sizes is None: + if self.block_sizes is None: return None - - block_sizes_list = [] + # If there are no integer dimension entries with a meaningful block size, treat as no block quant. + has_dim_sizes = any( + isinstance(k, int) and (v is not None and v != 1) for k, v in self.block_sizes.items() + ) + if not has_dim_sizes: + return None + + block_sizes_list: list[int] = [] for dim in range(len(shape)): # Check both positive and negative dimension indices dim_negative = dim - len(shape) block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_list
1005-1007
: Centralize deletion of _original_input_shape in a finally block.
Move this into a single finally at the end of forward so it runs regardless of success/failure.
Apply:
- if hasattr(self, "_original_input_shape"):
-     delattr(self, "_original_input_shape")
- return outputs
+ try:
+     return outputs
+ finally:
+     if hasattr(self, "_original_input_shape"):
+         delattr(self, "_original_input_shape")
modelopt/torch/quantization/export_onnx.py (3)
294-321
: Remove brittle default [1,1,128,1] and validate rank in _fp8_block_dequantize.
Defaulting silently is dangerous and shape-dependent. Require explicit block_sizes and assert correctness. (Same concern raised earlier.)
Apply:
-def _fp8_block_dequantize( +def _fp8_block_dequantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, scales: torch.Value, trt_high_precision_dtype: str, otype: str | None = None, - block_sizes: list = [1, 1, 128, 1], + block_sizes: list, ): """Helper Function for Dequantization.""" output_shape = sym_help._get_tensor_sizes(inputs) + # Validate block shape + rank = sym_help._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) + assert all(isinstance(b, int) and b > 0 for b in block_sizes), ( + "All entries in block_shape must be positive integers." + ) + if otype is None: + otype = inputs.type().scalarType()
323-345
: Handle non-Python amax safely and validate block_shapes before the block Q/DQ path.
float(amax) will break when amax is a graph Value/0‑dim tensor; also assert block_shapes align with input rank before calling block ops. (Echoing prior comment.)
Apply:
def export_fp8( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, - amax: float | None, + amax: float | None, trt_high_precision_dtype: str | None, block_sizes: list | None, ): """Export quantized model to FP8 ONNX.""" - scale = 1.0 if amax is None else 448.0 / float(amax) + if amax is None: + scale = 1.0 + else: + amax_const = sym_help._get_const(amax, "f", "amax") + # If not a constant at export time, fall back to neutral scale to avoid exporter errors. + scale = 1.0 if (amax_const is None or amax_const == 0) else 448.0 / float(amax_const) @@ - if not block_sizes: + if not block_sizes: q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) else: + # Validate block shape early + rank = sym_help._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) + assert all(isinstance(b, int) and b > 0 for b in block_sizes), ( + "All entries in block_shape must be positive integers." + ) q_tensor, scales_output = _fp8_block_quantize( g, inputs, trt_high_precision_dtype, block_sizes ) return _fp8_block_dequantize( g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes )
238-265
: Validate block_shape against input rank and values in _fp8_block_quantize.
Guard against mismatched ranks and non-positive entries to avoid invalid custom op attributes at export time.
Apply:
def _fp8_block_quantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, trt_high_precision_dtype: str, block_sizes: list, ): """Helper Function for Quantization.""" output_shape = sym_help._get_tensor_sizes(inputs) + # Validate block shape + rank = sym_help._get_tensor_rank(inputs) + assert rank is not None, "Input rank must be known at export time." + assert isinstance(block_sizes, (list, tuple)) and len(block_sizes) == rank, ( + f"block_shape length ({len(block_sizes)}) must match input rank ({rank})." + ) + assert all(isinstance(b, int) and b > 0 for b in block_sizes), ( + "All entries in block_shape must be positive integers." + )
🧹 Nitpick comments (2)
modelopt/torch/quantization/export_onnx.py (1)
512-518
: Pre-validate q/k block_shapes vs tensor ranks to fail fast.
Catch mismatches early instead of deep inside TRT ops.
Apply:
- query_scaled = export_fp8( + # Sanity-check block shapes + for name, t, bs in (("q", query_scaled, q_block_shape), ("k", key_transposed_scaled, k_block_shape)): + if bs is not None: + r = sym_help._get_tensor_rank(t) + assert r is not None and len(bs) == r, f"{name}_block_shape must match rank ({r})." + query_scaled = export_fp8( g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape ) @@ - key_transposed_scaled = export_fp8( + key_transposed_scaled = export_fp8( g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape )modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
962-965
: Guarantee cleanup of _original_input_shape via try/finally; fix typo.
Ensure the attribute is deleted even on exceptions; also fix "isnt" -> "isn't".
Apply:
- # Reshape is required if the logic isnt handled in the simulation kernel - self._setup_for_blockquant(inputs) - setattr(self, "_original_input_shape", inputs.shape) - inputs = self._process_for_blockquant(inputs) + # Reshape is required if the logic isn't handled in the simulation kernel + cleanup_original_shape = False + try: + self._setup_for_blockquant(inputs) + setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_shape = True + inputs = self._process_for_blockquant(inputs) + except Exception: + # Make sure we don't leak transient attributes on failure + if cleanup_original_shape and hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape") + raise
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype (405-407)
trt_high_precision_dtype (410-411)
block_sizes (287-289)
block_sizes (292-294)
amax (231-236)
amax (239-250)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/export_onnx.py (2)
545-549
: Softmax stays non-block — good call.
Keeping the softmax output per-tensor avoids mismatched block semantics across BMM2.
558-559
: Block‑quantizing V path looks correct.
The cast order aligns with the earlier Q/K path.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
701-714
: Pass block_sizes to scaled_e4m3 by keyword to avoid positional drift.
rg found only call sites (tests and histogram.py); I couldn't locate the scaled_e4m3 definition—confirm it accepts a block_sizes kwarg. If it does, apply:
  outputs = scaled_e4m3(
      inputs,
      amax,
      self._get_bias(inputs),
      E,
      M,
      self._trt_high_precision_dtype,
      self._pass_through_bwd,
-     block_sizes_list,
+     block_sizes=block_sizes_list,
  )
Signed-off-by: Jingyu Xin <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
649-671
: Don’t force “block mode” when block_sizes has only metadata; return None.
Current logic returns [1, ...] even if block_sizes has no per‑dim sizes, unintentionally enabling block handling downstream. Return None when there are no integer keys with v not in {None, 1}. Add typing while here.
- def _get_block_sizes_list(self, shape): + def _get_block_sizes_list(self, shape) -> list[int] | None: @@ - if self.block_sizes is None: - return None - - block_sizes_list = [] + if self.block_sizes is None: + return None + # If there are no per-dimension entries with a real block size, treat as no block quant. + has_dim_sizes = any( + isinstance(k, int) and (v is not None and v != 1) + for k, v in self.block_sizes.items() + ) + if not has_dim_sizes: + return None + + block_sizes_list: list[int] = [] for dim in range(len(shape)): # Check both positive and negative dimension indices dim_negative = dim - len(shape) block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_list
1006-1007
: Guarantee deletion of _original_input_shape with try/finally (prev feedback).
Move the deletion into a finally block guarded by the local flag so it always runs.
- if hasattr(self, "_original_input_shape"):
-     delattr(self, "_original_input_shape")
- return outputs
+ try:
+     return outputs
+ finally:
+     if cleanup_original_input_shape and hasattr(self, "_original_input_shape"):
+         delattr(self, "_original_input_shape")
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
962-966
: Make _original_input_shape cleanup exception‑safe: set a flag when set, clean in finally.
Set a local flag when creating the attribute so the finally block can reliably clean it even if an exception occurs later.
if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): # Reshape is required if the logic is not handled in the simulation kernel # Only MX format and NVFP4 reshape are currently supported by the kernel. self._setup_for_blockquant(inputs) - setattr(self, "_original_input_shape", inputs.shape) + setattr(self, "_original_input_shape", inputs.shape) + cleanup_original_input_shape = True inputs = self._process_for_blockquant(inputs)Add the flag near the top of forward (before this block):
- # Rotating the input
+ cleanup_original_input_shape = False
+ # Rotating the input
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
701-704
: LGTM: using pre‑reshape shape to derive per‑dim blocks.
Using _original_input_shape avoids mismapping after reshape/flatten.
705-714
: Pass block_sizes as a keyword; confirm the scaled_e4m3 signature and update callers
Change this call to pass block_sizes by name to avoid positional-argument drift; before merging, confirm the scaled_e4m3 definition accepts a named block_sizes parameter (or update all callers if the signature changed). Location: modelopt/torch/quantization/nn/modules/tensor_quantizer.py (around lines 705–714).
  outputs = scaled_e4m3(
      inputs,
      amax,
      self._get_bias(inputs),
      E,
      M,
      self._trt_high_precision_dtype,
      self._pass_through_bwd,
-     block_sizes_list,
+     block_sizes=block_sizes_list,
  )
Quick verification commands to run locally:
- rg -nP 'def\s+scaled_e4m3\s*\(' -C2
- rg -nP '\bscaled_e4m3\s*\(' -C2
- rg -nP '\bscaled_e4m3\s*\([^)]*block_sizes\s*=' -n
Signed-off-by: Jingyu Xin <[email protected]>
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
674-687
: scaled_e4m3 signature change breaks existing callers—use keyword arguments.
You inserted block_sizes_list as the 3rd positional argument to scaled_e4m3, but other call sites (e.g., modelopt/torch/quantization/calib/histogram.py:311 and test files) still pass bias, E, M positionally. This will cause argument-position mismatches and break tests and ONNX export.
Call scaled_e4m3 with keyword arguments to avoid positional-shift bugs.
  # Convert block_sizes dict to list format
  # Use original input shape if available (before reshaping), otherwise use current shape
  shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
  block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
- outputs = scaled_e4m3(
-     inputs,
-     amax,
-     self._get_bias(inputs),
-     E,
-     M,
-     self._trt_high_precision_dtype,
-     self._pass_through_bwd,
-     block_sizes_list,
- )
+ outputs = scaled_e4m3(
+     inputs=inputs,
+     amax=amax,
+     bias=self._get_bias(inputs),
+     E=E,
+     M=M,
+     trt_high_precision_dtype=self._trt_high_precision_dtype,
+     pass_through_bwd=self._pass_through_bwd,
+     block_sizes=block_sizes_list,
+ )
Based on learnings
modelopt/torch/quantization/plugins/diffusers.py (1)
253-271
: Normalize optional block-shapes to Python int-lists for ONNX attrs.
When parse_args passes 'v', convert Values to lists; keep None when absent.
  def symbolic(
      g: "GraphContext",
      query: "torch._C.Value",
@@
- ):
-     """Symbolic method."""
+ ):
+     """Symbolic method."""
+     # Normalize optional block-shapes to Python int lists for ONNX attributes
+     if q_block_shape is not None:
+         q_block_shape = symbolic_helper._maybe_get_const(q_block_shape, "is")
+     if k_block_shape is not None:
+         k_block_shape = symbolic_helper._maybe_get_const(k_block_shape, "is")
+     if v_block_shape is not None:
+         v_block_shape = symbolic_helper._maybe_get_const(v_block_shape, "is")
      return export_fp8_mha(
          g,
          query,
          key,
          value,
♻️ Duplicate comments (5)
modelopt/torch/quantization/export_onnx.py (2)
299-325
: Remove brittle default block_sizes=[1,1,128,1] and validate rank.
A hardcoded default of [1, 1, 128, 1] is unsafe—it assumes a specific tensor rank and shape. Require callers to pass block_sizes
explicitly and validate it against the input tensor rank.def _fp8_block_dequantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, scales: torch.Value, trt_high_precision_dtype: str, otype: str | None = None, - block_sizes: list = [1, 1, 128, 1], + block_sizes: list | None = None, ): """Helper Function for Dequantization.""" + assert block_sizes is not None, "block_sizes must be provided for block dequantization." output_shape = sym_help._get_tensor_sizes(inputs) + assert len(block_sizes) == len(output_shape), ( + f"block_sizes length ({len(block_sizes)}) must match input rank ({len(output_shape)})." + ) assert trt_high_precision_dtype in (otype, "Float"), ( "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." ) out = g.op( "trt::TRT_BlockDequantize", inputs, scales, block_shape_i=block_sizes, ).setType( inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape) ) # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT # custom ops, so cast the output if needed. if trt_high_precision_dtype != otype: out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] return outBased on learnings
328-349
: Handle tensor/Value amax correctly during symbolic export.
float(amax) on line 336 will fail if amax is a graph Value or 0-dim tensor. Use _get_const
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
622-643
: Guard against metadata-only block_sizes dicts returning [1,1,1,1].If
block_sizes
contains only non-integer keys like{"type": "dynamic", "scale_bits": (4,3)}
, you return[1]*rank
, which signals "block mode" downstream even though no dimension has an actual block size. This can incorrectly trigger the FP8 block path.Return
None
when there are no integer dimension keys with a block size > 1.def _get_block_sizes_list(self, shape): """Convert block_sizes dict to list format based on tensor shape. Args: shape: The tensor shape to use for conversion (can be tuple or torch.Size) Returns: List of block sizes for each dimension, or None if block_sizes is None Example: block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1] """ if self.block_sizes is None: return None + + # Check if there are any actual per-dimension block sizes (not just metadata) + has_dim_sizes = any( + isinstance(k, int) and v not in (None, 1) for k, v in self.block_sizes.items() + ) + if not has_dim_sizes: + return None block_sizes_list = [] for dim in range(len(shape)): # Check both positive and negative dimension indices dim_negative = dim - len(shape) block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_listBased on learnings
975-976
: Cleanup of _original_input_shape should be in a try-finally block.The deletion of
_original_input_shape
at the end of forward is not guaranteed if an exception occurs during the forward pass. This can leave stale state that corrupts subsequent calls.Wrap the forward logic that depends on
_original_input_shape
in a try-finally block to guarantee cleanup.if ( self.block_sizes is not None and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): # Reshape is required if the logic is not handled in the simulation kernel # Only MX format and NVFP4 reshape are currently supported by the kernel. self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) inputs = self._process_for_blockquant(inputs) - outputs = inputs - - block_size = None - if self._if_calib and not self._dynamic: - if self._calibrator is None: - raise RuntimeError("Calibrator was not created.") - # Shape is only known when it sees the first tensor - if self.block_sizes is not None and self.block_sizes.get("type", None) == "dynamic": - block_size = self.block_sizes.get(-1, None) or self.block_sizes.get( - inputs.dim() - 1, None - ) - assert block_size is not None, "block size for dynamic quantization not found." - - self.collect(inputs) - - if self._if_quant: - # Check if the input tensor is contiguous - # Non-contiguous tensors will generate incorrect FP4 quantization results - if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): - inputs.data = inputs.data.contiguous() - if self.fake_quant: - outputs = self._fake_quantize(inputs) - elif not self._dequantize: - outputs = self._real_quantize(inputs) - else: - raise ValueError( - "self._dequantize is True and self.fake_quant is False. " - "This case should have been handled." - ) - - if ( - self.block_sizes is not None - and self.block_sizes.get("type", None) != "dynamic" - and self._fake_quant - ): - outputs = self._reset_to_original_shape(outputs) - - if hasattr(self, "_original_input_shape"): - delattr(self, "_original_input_shape") + try: + outputs = inputs + + block_size = None + if self._if_calib and not self._dynamic: + if self._calibrator is None: + raise RuntimeError("Calibrator was not created.") + # Shape is only known when it sees the first tensor + if self.block_sizes is not None and self.block_sizes.get("type", None) == "dynamic": + block_size = self.block_sizes.get(-1, None) or self.block_sizes.get( + inputs.dim() - 1, None + ) + assert block_size is not None, "block size for dynamic quantization not found." + + self.collect(inputs) + + if self._if_quant: + # Check if the input tensor is contiguous + # Non-contiguous tensors will generate incorrect FP4 quantization results + if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): + inputs.data = inputs.data.contiguous() + if self.fake_quant: + outputs = self._fake_quantize(inputs) + elif not self._dequantize: + outputs = self._real_quantize(inputs) + else: + raise ValueError( + "self._dequantize is True and self.fake_quant is False. " + "This case should have been handled." + ) + + if ( + self.block_sizes is not None + and self.block_sizes.get("type", None) != "dynamic" + and self._fake_quant + ): + outputs = self._reset_to_original_shape(outputs) + finally: + if hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape") + return outputsBased on learnings
modelopt/torch/quantization/plugins/diffusers.py (1)
231-235
: Fix ONNX parse_args: allow optional scales and block-shapes (prevents export crash).
- q/k/v scales can be None for dynamic quantization; using 't' rejects None.
- q/k/v block-shapes are optional; using 'is' rejects None.
Update parse_args to accept optionals with 'v' for these 6 entries.
- @symbolic_helper.parse_args(
-     "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
- )
+ @symbolic_helper.parse_args(
+     "v", "v", "v", "v", "f", "b", "v", "v", "v", "v", "s", "b", "v", "v", "v"
+ )
Also applies to: 243-251
🧹 Nitpick comments (4)
modelopt/torch/quantization/plugins/diffusers.py (4)
120-135
: Clarify assertion message for mixed dynamic/static Q/K/V quantizers.Message is cryptic. Make intent explicit.
- ), "QKV QDQS must be in the same type" + ), "q/k/v quantizers must all be either dynamic or static"
243-251
: Adjust symbolic type hints to reflect actual tensor/value types.Scales are Values/Tensors, not floats; block-shapes are list[int] | None.
- scale: torch._C.Value | None = None, - q_quantized_scale: float | None = 1.0, - k_quantized_scale: float | None = 1.0, - v_quantized_scale: float | None = 1.0, + scale: "torch._C.Value | None" = None, + q_quantized_scale: "torch._C.Value | None" = None, + k_quantized_scale: "torch._C.Value | None" = None, + v_quantized_scale: "torch._C.Value | None" = None, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, - q_block_shape: list | None = None, - k_block_shape: list | None = None, - v_block_shape: list | None = None, + q_block_shape: "list[int] | None" = None, + k_block_shape: "list[int] | None" = None, + v_block_shape: "list[int] | None" = None,
136-139
: Block-shape None path is now safe once parse_args is fixed; optional naming nit.With the parse_args/symbolic normalization above, passing None here is OK. Consider renaming variables to *_block_shape for consistency with FP8SDPA.forward.
- q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr] - k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr] - v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr] + q_block_shape = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr] + k_block_shape = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr] + v_block_shape = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr]Remember to update the call site accordingly (Lines 155-157).
155-157
: *If you rename to _block_shape, update the call site accordingly.- q_block_sizes, - k_block_sizes, - v_block_sizes, + q_block_shape, + k_block_shape, + v_block_shape,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
modelopt/torch/quantization/export_onnx.py (6 hunks)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (5 hunks)
modelopt/torch/quantization/plugins/diffusers.py (6 hunks)
modelopt/torch/quantization/tensor_quant.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
block_sizes
(293-295)block_sizes
(298-300)amax
(236-241)amax
(244-255)trt_high_precision_dtype
(411-413)trt_high_precision_dtype
(416-417)modelopt/torch/quantization/export_onnx.py (1)
export_fp8
(328-349)
modelopt/torch/quantization/export_onnx.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
trt_high_precision_dtype
(411-413)trt_high_precision_dtype
(416-417)block_sizes
(293-295)block_sizes
(298-300)amax
(236-241)amax
(244-255)
modelopt/torch/quantization/plugins/diffusers.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
_get_amax
(496-505)_get_block_sizes_list
(622-643)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/quantization/tensor_quant.py (3)
409-424
: LGTM—ScaledE4M3Function signature correctly extended for block_sizes.
The symbolic and forward signatures are updated consistently:
- parse_args extended to ("v", "t", "t", "is", "i", "i", "s", "b") to accept block_sizes as int or sequence.
- The export_fp8 call correctly passes block_sizes as the 4th argument.
- The change aligns with the broader PR goal of propagating block sizes through the FP8 export path.
437-437
: LGTM—forward signature extended with block_sizes parameter.
The block_sizes=None parameter is appended to the end of the signature, maintaining backward compatibility for callers that do not pass it. This aligns with the symbolic signature update.
465-465
: LGTM—backward num_args updated to 8 to match new signature.
The num_args=8 correctly accounts for the additional block_sizes parameter added to the forward signature. This ensures the backward pass handles all arguments properly.
modelopt/torch/quantization/export_onnx.py (5)
243-269
: LGTM—_fp8_block_quantize correctly implements block quantization.
The helper correctly:
- Casts inputs to the target dtype if needed.
- Calls trt::TRT_DynamicQuantize with block_shape_i=block_sizes.
- Returns both quantized output and scales.
- Sets the output type to uint8 as expected.
444-452
: LGTM—export_fp8_mha signature correctly extended for block shapes.
The signature updates add:
- q_quantized_scale, k_quantized_scale, v_quantized_scale as float | None.
- q_block_shape, k_block_shape, v_block_shape as list | None.
This aligns with the broader PR goal of propagating per-tensor block shapes through the FP8-MHA export path.
529-534
: LGTM—export_fp8 called with q_block_shape for query quantization.
The call correctly passes q_block_shape to export_fp8, enabling block quantization for the query tensor in the FP8-MHA path.
562-563
: LGTM—softmax output correctly uses None for block_shape.
The comment "We cannot do block quant for the softmax's output" is accurate, and passing None as the block_shape ensures the softmax output is not block-quantized.
573-573
: LGTM—export_fp8 called with v_block_shape for value quantization.
The call correctly passes v_block_shape to export_fp8, enabling block quantization for the value tensor in the FP8-MHA path.
modelopt/torch/quantization/plugins/diffusers.py (1)
268-270
: Forwarding of block-shapes to export looks correct.
No issues forwarding q/k/v block-shapes once normalized.
# Reshape is required if the logic is not handled in the simulation kernel
# Only MX format and NVFP4 reshape are currently supported by the kernel.
self._setup_for_blockquant(inputs)
setattr(self, "_original_input_shape", inputs.shape)
inputs = self._process_for_blockquant(inputs)
Storing _original_input_shape as an attribute creates state leakage risks.
Setting _original_input_shape on self in the forward pass can cause issues if an exception is raised before cleanup (line 976) or if multiple forward calls interleave (e.g., during model export or distributed training). The attribute can linger and pollute subsequent forward passes.
Consider using a local variable or a context manager to ensure cleanup even on exceptions.
if (
self.block_sizes is not None
and self.block_sizes.get("type", None) != "dynamic"
and self._fake_quant
):
# Reshape is required if the logic is not handled in the simulation kernel
# Only MX format and NVFP4 reshape are currently supported by the kernel.
self._setup_for_blockquant(inputs)
- setattr(self, "_original_input_shape", inputs.shape)
+ original_input_shape = inputs.shape
+ self._original_input_shape = original_input_shape
inputs = self._process_for_blockquant(inputs)
Then wrap the forward logic in try-finally as noted in the next comment.
Based on learnings
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 932
to 936, avoid assigning self._original_input_shape inside forward since it leaks
state across calls and during exceptions; instead store the original input shape
in a local variable (e.g., original_input_shape = inputs.shape) or use a
short-lived context manager, and wrap the forward transformation/processing
logic in a try-finally block so any necessary cleanup (previously at line 976)
always runs even if an exception occurs; remove the attribute assignment and
update subsequent references to use the local variable or the context-managed
value.
@staticmethod
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
@symbolic_helper.parse_args(
    "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
)
def symbolic(
🧩 Analysis chain
Repo-wide check: ensure no other symbolics still use 't' for optional scales.
Run to verify and avoid future breakages.
🏁 Script executed:
#!/bin/bash
# Find symbolics that parse optional scales with 't' (should be 'v')
rg -nP '@symbolic_helper\.parse_args\([^)]*\)' -C2 | rg -n 'parse_args\((.*)\)' -o | sed -E 's/.*parse_args\((.*)\).*/\1/' | awk 'BEGIN{RS=","} /"t"/{print FILENAME ":" NR ": " $0}' FILENAME=/dev/stdin
# Narrow search for q/k/v scale params in this file
rg -n 'parse_args\(' modelopt/torch/quantization/plugins/diffusers.py -n -C1
Length of output: 200
🏁 Script executed:
#!/bin/bash
# List all @symbolic_helper.parse_args calls containing "t" across Python files
rg -nP '@symbolic_helper\.parse_args\([^)]*\"t\"[^)]*\)' -g '*.py' -C1
Length of output: 779
🏁 Script executed:
#!/bin/bash
# Show context around parse_args usages in tensor_quant.py
sed -n '320,335p' modelopt/torch/quantization/tensor_quant.py
sed -n '405,420p' modelopt/torch/quantization/tensor_quant.py
sed -n '500,515p' modelopt/torch/quantization/tensor_quant.py
Length of output: 1170
Replace 't' with 'v' for optional scale parameters in all symbolic definitions
- modelopt/torch/quantization/tensor_quant.py: FakeTensorQuantFunction.symbolic (parse_args at line 326)
- modelopt/torch/quantization/tensor_quant.py: ScaledE4M3Function.symbolic (parse_args at line 409)
- modelopt/torch/quantization/tensor_quant.py: DynamicBlockQuantizationFunction.symbolic (parse_args at line 506)
What does this PR do?
Type of change: new feature
Overview:
Usage
Testing
evaluation
feature since the TRT kernel isn't ready. No test cases are required at this time; we will add the test cases next month.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements