15 changes: 15 additions & 0 deletions examples/diffusers/quantization/config.py
@@ -31,6 +31,21 @@
"algorithm": "max",
}

FP8_SAGE_DEFAULT_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
"*input_quantizer": {"num_bits": (4, 3), "axis": None},
"*output_quantizer": {"enable": False},
"*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
"*softmax_quantizer": {
"num_bits": (4, 3),
"axis": None,
},
"default": {"enable": False},
},
"algorithm": "max",
}

INT8_DEFAULT_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": 8, "axis": 0},
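For orientation, a minimal sketch of how a preset like FP8_SAGE_DEFAULT_CONFIG would typically be fed to Model Optimizer's quantize API, mirroring the calibration flow the other presets in this example use (pipe and backbone below are placeholders for a loaded diffusers pipeline and its denoiser; they are not defined in this PR):

import modelopt.torch.quantization as mtq
from config import FP8_SAGE_DEFAULT_CONFIG  # assumes this script sits next to config.py

calibration_prompts = ["a photo of an astronaut riding a horse on mars"]

def forward_loop(mod):
    # Placeholder calibration loop: run a few prompts through the pipeline so the
    # static quantizers can collect amax statistics.
    for prompt in calibration_prompts:
        pipe(prompt, num_inference_steps=20)

# Quantize the diffusion backbone in place with the new SAGE-style FP8 preset.
mtq.quantize(backbone, FP8_SAGE_DEFAULT_CONFIG, forward_loop=forward_loop)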
2 changes: 1 addition & 1 deletion examples/diffusers/quantization/quantize.py
@@ -939,7 +939,7 @@ def forward_loop(mod):
backbone,
model_config.model_type,
quant_config.format,
quantize_mha=QuantizationConfig.quantize_mha,
quantize_mha=quant_config.quantize_mha,
)
logger.info("Quantization process completed successfully!")

100 changes: 86 additions & 14 deletions modelopt/torch/quantization/export_onnx.py
@@ -240,8 +240,37 @@ def _fp8_quantize(
return q_op


def _fp8_block_quantize(
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
trt_high_precision_dtype: str,
block_sizes: list,
):
"""Helper Function for Quantization."""
output_shape = sym_help._get_tensor_sizes(inputs)

# TRT StronglyType only supports FP16 QDQs
# custom ops, so cast the input if needed.
input_type = inputs.type().scalarType()
assert trt_high_precision_dtype in (input_type, "Float"), (
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
)
if trt_high_precision_dtype != input_type:
inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype])
quantized_output, scales_output = g.op(
"trt::TRT_DynamicQuantize",
inputs,
block_shape_i=block_sizes,
outputs=2,
scale_type_i=onnx_dtype_map["Float"],
output_type_i=onnx_dtype_map["Float8"],
)
quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
return quantized_output, scales_output


def _fp8_dequantize(
g: "GraphContext",
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
scale_inv: float,
trt_high_precision_dtype: str,
@@ -267,20 +296,57 @@ def _fp8_dequantize(
return out


def _fp8_block_dequantize(
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
scales: torch.Value,
trt_high_precision_dtype: str,
otype: str | None = None,
block_sizes: list = [1, 1, 128, 1],
):
"""Helper Function for Dequantization."""
output_shape = sym_help._get_tensor_sizes(inputs)
assert trt_high_precision_dtype in (otype, "Float"), (
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
)
out = g.op(
"trt::TRT_BlockDequantize",
inputs,
scales,
block_shape_i=block_sizes,
).setType(
inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
)

# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the output if needed.
if trt_high_precision_dtype != otype:
out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index]
return out


def export_fp8(
g: "GraphContext",
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
amax: float,
amax: float | None,
trt_high_precision_dtype: str | None,
block_sizes: list | None,
):
"""Export quantized model to FP8 ONNX."""
scale = 1.0 if amax is None else 448.0 / float(amax)
otype = inputs.type().scalarType()
if trt_high_precision_dtype is None:
trt_high_precision_dtype = otype

q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
if not block_sizes:
q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
else:
q_tensor, scales_output = _fp8_block_quantize(
g, inputs, trt_high_precision_dtype, block_sizes
)
return _fp8_block_dequantize(
g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
)


def scaled_dot_product_attention(
@@ -375,12 +441,15 @@ def export_fp8_mha(
attn_mask: "torch._C.Value | None" = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: "torch._C.Value | None" = None,
q_quantized_scale: float = 1.0,
k_quantized_scale: float = 1.0,
v_quantized_scale: float = 1.0,
scale: torch._C.Value | None = None,
q_quantized_scale: float | None = 1.0,
k_quantized_scale: float | None = 1.0,
v_quantized_scale: float | None = 1.0,
high_precision_flag: str = "Half",
disable_fp8_mha: bool = True,
q_block_shape: list | None = None,
k_block_shape: list | None = None,
v_block_shape: list | None = None,
):
r"""Export quantized fMHA to FP8 ONNX.

@@ -457,10 +526,12 @@ def export_fp8_mha(
v_input_dtype = value.type().scalarType()
if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}:
raise ValueError("The quantized MHA must have 16-bit inputs.")
query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag)
query_scaled = export_fp8(
g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
)
query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"])
key_transposed_scaled = export_fp8(
g, key_transposed_scaled, k_quantized_scale, high_precision_flag
g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
)
key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"])
mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
@@ -488,7 +559,8 @@ def export_fp8_mha(

if not disable_fp8_mha:
# Softmax's output scale is hard coded to 1.0
attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag)
# We cannot do block quant for the softmax's output
attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"])

if dropout_p != 0:
@@ -498,7 +570,7 @@ def export_fp8_mha(
g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
)
if not disable_fp8_mha:
value = export_fp8(g, value, v_quantized_scale, high_precision_flag)
value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
value = g.op("Cast", value, to_i=onnx_dtype_map["Float"])
return g.op(
"Cast",
37 changes: 34 additions & 3 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -619,14 +619,37 @@ def _real_quantize(self, inputs):
self._dequantize = True
return outputs

def _get_block_sizes_list(self, shape):
"""Convert block_sizes dict to list format based on tensor shape.

Args:
shape: The tensor shape to use for conversion (can be tuple or torch.Size)

Returns:
List of block sizes for each dimension, or None if block_sizes is None

Example:
block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
"""
if self.block_sizes is None:
return None

block_sizes_list = []
for dim in range(len(shape)):
# Check both positive and negative dimension indices
dim_negative = dim - len(shape)
block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
block_sizes_list.append(block_size if block_size is not None else 1)
return block_sizes_list

def _fake_quantize(self, inputs):
"""Fake quantization."""
amax = None
if not self.is_mx_format:
amax = self._get_amax(inputs)

if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic":
# Block quantization, including dynamic and static block quantization
# Double scale Block quantization, including dynamic and static block quantization
block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
inputs.dim() - 1, None
)
@@ -648,6 +671,10 @@ def _fake_quantize(self, inputs):
# Float-point quantization, e.g., FP8
E, M = self._num_bits # noqa: N806

# Convert block_sizes dict to list format
# Use original input shape if available (before reshaping), otherwise use current shape
shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
outputs = scaled_e4m3(
inputs,
amax,
Expand All @@ -656,6 +683,7 @@ def _fake_quantize(self, inputs):
M,
self._trt_high_precision_dtype,
self._pass_through_bwd,
block_sizes_list,
)

else:
@@ -901,9 +929,10 @@ def forward(self, inputs):
and self.block_sizes.get("type", None) != "dynamic"
and self._fake_quant
):
# Tensor reshaping is required for static block quantization
# Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
# Reshape is required if the logic is not handled in the simulation kernel
# Only MX format and NVFP4 reshape are currently supported by the kernel.
self._setup_for_blockquant(inputs)
setattr(self, "_original_input_shape", inputs.shape)
inputs = self._process_for_blockquant(inputs)
Comment on lines +932 to 936

⚠️ Potential issue | 🟠 Major

Storing _original_input_shape as an attribute creates state leakage risks.

Setting _original_input_shape on self in the forward pass can cause issues if an exception is raised before cleanup (line 976) or if multiple forward calls interleave (e.g., during model export or distributed training). The attribute can linger and pollute subsequent forward passes.

Consider using a local variable or a context manager to ensure cleanup even on exceptions.

         if (
             self.block_sizes is not None
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
             # Reshape is required if the logic is not handled in the simulation kernel
             # Only MX format and NVFP4 reshape are currently supported by the kernel.
             self._setup_for_blockquant(inputs)
-            setattr(self, "_original_input_shape", inputs.shape)
+            original_input_shape = inputs.shape
+            self._original_input_shape = original_input_shape
             inputs = self._process_for_blockquant(inputs)

Then wrap the forward logic in try-finally as noted in the next comment.


Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 932
to 936, avoid assigning self._original_input_shape inside forward since it leaks
state across calls and during exceptions; instead store the original input shape
in a local variable (e.g., original_input_shape = inputs.shape) or use a
short-lived context manager, and wrap the forward transformation/processing
logic in a try-finally block so any necessary cleanup (previously at line 976)
always runs even if an exception occurs; remove the attribute assignment and
update subsequent references to use the local variable or the context-managed
value.
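A minimal sketch of the try/finally pattern the reviewer is suggesting (illustrative only; needs_block_reshape and _run_quantization are hypothetical stand-ins for the existing condition and the rest of forward, not real attributes of the quantizer):

def forward(self, inputs):
    try:
        if needs_block_reshape:  # stands in for the existing static block-quant condition
            self._setup_for_blockquant(inputs)
            self._original_input_shape = inputs.shape
            inputs = self._process_for_blockquant(inputs)
        outputs = self._run_quantization(inputs)  # stands in for the quantize/dequantize body
    finally:
        # Cleanup runs even if quantization raises, so no stale shape leaks into
        # the next forward call.
        if hasattr(self, "_original_input_shape"):
            delattr(self, "_original_input_shape")
    return outputs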


outputs = inputs
@@ -943,6 +972,8 @@ def forward(self, inputs):
):
outputs = self._reset_to_original_shape(outputs)

if hasattr(self, "_original_input_shape"):
delattr(self, "_original_input_shape")
return outputs

def _short_amax(self, fmt=".4f"):
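Returning to the _get_block_sizes_list helper added above: a standalone illustration of the dict-to-list conversion it performs, mirroring the docstring example (written for clarity here, not copied verbatim from the module):

# block_sizes keys are dimension indices; negative indices count from the last dim.
block_sizes = {-2: 32}
shape = (2, 24, 4608, 128)

block_sizes_list = []
for dim in range(len(shape)):
    negative_dim = dim - len(shape)
    block = block_sizes.get(dim) or block_sizes.get(negative_dim)
    block_sizes_list.append(block if block is not None else 1)

print(block_sizes_list)  # [1, 1, 32, 1]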
47 changes: 39 additions & 8 deletions modelopt/torch/quantization/plugins/diffusers.py
@@ -117,9 +117,26 @@ def _quantized_sdpa(self, *args, **kwargs):
key = self.k_bmm_quantizer(key)
value = self.v_bmm_quantizer(value)

q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
if (
not self.q_bmm_quantizer._dynamic
and not self.k_bmm_quantizer._dynamic
and not self.v_bmm_quantizer._dynamic
):
q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
else:
assert (
self.q_bmm_quantizer._dynamic
and self.k_bmm_quantizer._dynamic
and self.v_bmm_quantizer._dynamic
), "QKV QDQS must be in the same type"
q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None

# Get block sizes lists for each quantizer if needed
q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr]
k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr]
v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr]

# We don't need to calibrate the output of softmax
return self.bmm2_output_quantizer(
Expand All @@ -135,6 +152,9 @@ def _quantized_sdpa(self, *args, **kwargs):
if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype")
else "Half",
self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
q_block_sizes,
k_block_sizes,
v_block_sizes,
)
)

@@ -188,6 +208,9 @@ def forward(
v_quantized_scale=None,
high_precision_flag=None,
disable_fp8_mha=True,
q_block_shape: list | None = None,
k_block_shape: list | None = None,
v_block_shape: list | None = None,
):
"""Forward method."""
ctx.save_for_backward(query, key, value, attn_mask)
Expand All @@ -206,7 +229,9 @@ def forward(
)

@staticmethod
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
@symbolic_helper.parse_args(
"v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
)
def symbolic(
Comment on lines 231 to 235

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Repo-wide check: ensure no other symbolics still use 't' for optional scales.

Run to verify and avoid future breakages.


🏁 Script executed:

#!/bin/bash
# Find symbolics that parse optional scales with 't' (should be 'v')
rg -nP '@symbolic_helper\.parse_args\([^)]*\)' -C2 | rg -n 'parse_args\((.*)\)' -o | sed -E 's/.*parse_args\((.*)\).*/\1/' | awk 'BEGIN{RS=","} /"t"/{print FILENAME ":" NR ": " $0}' FILENAME=/dev/stdin
# Narrow search for q/k/v scale params in this file
rg -n 'parse_args\(' modelopt/torch/quantization/plugins/diffusers.py -n -C1

Length of output: 200


🏁 Script executed:

#!/bin/bash
# List all @symbolic_helper.parse_args calls containing "t" across Python files
rg -nP '@symbolic_helper\.parse_args\([^)]*\"t\"[^)]*\)' -g '*.py' -C1

Length of output: 779


🏁 Script executed:

#!/bin/bash
# Show context around parse_args usages in tensor_quant.py
sed -n '320,335p' modelopt/torch/quantization/tensor_quant.py
sed -n '405,420p' modelopt/torch/quantization/tensor_quant.py
sed -n '500,515p' modelopt/torch/quantization/tensor_quant.py

Length of output: 1170


Replace 't' with 'v' for optional scale parameters in all symbolic definitions

  • modelopt/torch/quantization/tensor_quant.py: FakeTensorQuantFunction.symbolic (parse_args at line 326)
  • modelopt/torch/quantization/tensor_quant.py: ScaledE4M3Function.symbolic (parse_args at line 409)
  • modelopt/torch/quantization/tensor_quant.py: DynamicBlockQuantizationFunction.symbolic (parse_args at line 506)
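As background (an aside, not part of the PR), the parse_args descriptors map positionally onto the symbolic function's parameters after g, and the distinction the review hinges on is roughly the following; the decorated function here is a hypothetical example, not code from the repository:

from torch.onnx import symbolic_helper

# Common torch.onnx.symbolic_helper.parse_args descriptors:
#   "v"        pass the raw torch._C.Value through (safe for optional or dynamic args)
#   "t"        extract a constant torch.Tensor; errors if the arg is a non-constant graph value
#   "i"/"is"   int / list of ints;  "f" float;  "b" bool;  "s" string
@symbolic_helper.parse_args("v", "v", "f", "b")
def symbolic(g, inputs, optional_scale, dropout_p, is_causal):
    ...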

g: "GraphContext",
query: "torch._C.Value",
@@ -215,12 +240,15 @@ def symbolic(
attn_mask: "torch._C.Value | None" = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: "torch._C.Value | None" = None,
q_quantized_scale: float = 1.0,
k_quantized_scale: float = 1.0,
v_quantized_scale: float = 1.0,
scale: torch._C.Value | None = None,
q_quantized_scale: float | None = 1.0,
k_quantized_scale: float | None = 1.0,
v_quantized_scale: float | None = 1.0,
high_precision_flag: str = "Half",
disable_fp8_mha: bool = True,
q_block_shape: list | None = None,
k_block_shape: list | None = None,
v_block_shape: list | None = None,
):
"""Symbolic method."""
return export_fp8_mha(
Expand All @@ -237,4 +265,7 @@ def symbolic(
v_quantized_scale,
high_precision_flag,
disable_fp8_mha,
q_block_shape,
k_block_shape,
v_block_shape,
)
8 changes: 5 additions & 3 deletions modelopt/torch/quantization/tensor_quant.py
@@ -406,7 +406,7 @@ class ScaledE4M3Function(Function):
"""E4M3fy input with scale."""

@staticmethod
@symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b")
@symbolic_helper.parse_args("v", "t", "t", "is", "i", "i", "s", "b")
def symbolic(
g,
inputs,
@@ -416,11 +416,12 @@ def symbolic(
M=3, # noqa: N803
trt_high_precision_dtype=None,
pass_through_bwd=False,
block_sizes=None,
):
"""ONNX symbolic function."""
from .export_onnx import export_fp8

return export_fp8(g, inputs, amax, trt_high_precision_dtype)
return export_fp8(g, inputs, amax, trt_high_precision_dtype, block_sizes)

@staticmethod
# Default values could cause errors from TorchDynamo during torch.export
@@ -433,6 +434,7 @@ def forward(
M, # noqa: N803
trt_high_precision_dtype=None,
pass_through_bwd=False,
block_sizes=None,
):
"""Forward method."""
if E != 4 or M != 3:
@@ -460,7 +462,7 @@ def forward(
@staticmethod
def backward(ctx, grad_outputs):
"""Implements straight through estimation with clipping."""
return _fake_quant_backward_function(ctx, grad_outputs, num_args=7)
return _fake_quant_backward_function(ctx, grad_outputs, num_args=8)


def _dynamic_block_quantize_forward(