From 071f167bfec36bc7aca5c8b17da722f1177dbd41 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 15 Sep 2025 22:57:37 +0000 Subject: [PATCH 01/12] Add Sage Attn ONNX & Fixed a bug in diffusers Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/config.py | 15 ++++ examples/diffusers/quantization/quantize.py | 2 +- modelopt/torch/quantization/export_onnx.py | 89 ++++++++++++++++--- .../nn/modules/tensor_quantizer.py | 36 +++++++- .../torch/quantization/plugins/diffusers.py | 35 ++++++-- modelopt/torch/quantization/tensor_quant.py | 6 +- 6 files changed, 160 insertions(+), 23 deletions(-) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 0c151f8d5..ecfd75f4f 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -31,6 +31,21 @@ "algorithm": "max", } +FP8_SAGE_DEFAULT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, + "*input_quantizer": {"num_bits": (4, 3), "axis": None}, + "*output_quantizer": {"enable": False}, + "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}}, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + INT8_DEFAULT_CONFIG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index f94a4a1ad..fadbeda66 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -939,7 +939,7 @@ def forward_loop(mod): backbone, model_config.model_type, quant_config.format, - quantize_mha=QuantizationConfig.quantize_mha, + quantize_mha=quant_config.quantize_mha, ) logger.info("Quantization process completed successfully!") diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index 984dba0b7..e1fe73828 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -234,6 +234,34 @@ def _fp8_quantize( ) return q_op +def _fp8_block_quantize( + g: torch.onnx._internal.jit_utils.GraphContext, + inputs: torch.Value, + trt_high_precision_dtype: str, + block_sizes: list, +): + """Helper Function for Quantization.""" + output_shape = sym_help._get_tensor_sizes(inputs) + + # TRT StronglyType only supports FP16 QDQs + # custom ops, so cast the input if needed. + input_type = inputs.type().scalarType() + assert trt_high_precision_dtype in (input_type, "Float"), ( + "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." 
+ ) + if trt_high_precision_dtype != input_type: + inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype]) + quantized_output, scales_output = g.op( + "trt::TRT_DynamicQuantize", + inputs, + block_shape_i=block_sizes, + outputs=2, + scale_type_i=onnx_dtype_map["Float"], + output_type_i=onnx_dtype_map["Float8"], + ) + quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)) + return quantized_output, scales_output + def _fp8_dequantize( g: torch.onnx._internal.jit_utils.GraphContext, @@ -261,21 +289,58 @@ def _fp8_dequantize( out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] return out +def _fp8_block_dequantize( + g: torch.onnx._internal.jit_utils.GraphContext, + inputs: torch.Value, + scales: torch.Value, + trt_high_precision_dtype: str, + otype: str | None = None, + block_sizes: list = [1,1,128,1] +): + """Helper Function for Dequantization.""" + output_shape = sym_help._get_tensor_sizes(inputs) + assert trt_high_precision_dtype in (otype, "Float"), ( + "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." + ) + out = g.op( + "trt::TRT_BlockDequantize", + inputs, + scales, + block_shape_i=block_sizes, + ).setType( + inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape) + ) + + # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT + # custom ops, so cast the output if needed. + if trt_high_precision_dtype != otype: + out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] + return out + def export_fp8( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, - amax: float, + amax: float | None, trt_high_precision_dtype: str | None, + block_sizes: list | None, ): """Export quantized model to FP8 ONNX.""" scale = 1.0 if amax is None else 448.0 / float(amax) otype = inputs.type().scalarType() if trt_high_precision_dtype is None: trt_high_precision_dtype = otype + if not block_sizes: + q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) + return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) + else: + q_tensor, scales_output = _fp8_block_quantize( + g, inputs, trt_high_precision_dtype, block_sizes + ) + return _fp8_block_dequantize( + g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes + ) - q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) - return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) def scaled_dot_product_attention( @@ -360,11 +425,14 @@ def export_fp8_mha( dropout_p: float = 0.0, is_causal: bool = False, scale: torch._C.Value | None = None, - q_quantized_scale: float = 1.0, - k_quantized_scale: float = 1.0, - v_quantized_scale: float = 1.0, + q_quantized_scale: float | None = 1.0, + k_quantized_scale: float | None = 1.0, + v_quantized_scale: float | None = 1.0, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, + q_block_shape: list | None = None, + k_block_shape: list | None = None, + v_block_shape: list | None = None, ): r"""Export quantized fMHA to FP8 ONNX. 
@@ -430,10 +498,10 @@ def export_fp8_mha( v_input_dtype = value.type().scalarType() if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}: raise ValueError("The quantized MHA must have 16-bit inputs.") - query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag) + query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape) query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"]) key_transposed_scaled = export_fp8( - g, key_transposed_scaled, k_quantized_scale, high_precision_flag + g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape ) key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"]) mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) @@ -463,7 +531,8 @@ def export_fp8_mha( if not disable_fp8_mha: # Softmax's output scale is hard coded to 1.0 - attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag) + # We cannot do block quant for the softmax's output + attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"]) if dropout_p != 0: @@ -473,7 +542,7 @@ def export_fp8_mha( g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), ) if not disable_fp8_mha: - value = export_fp8(g, value, v_quantized_scale, high_precision_flag) + value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape) value = g.op("Cast", value, to_i=onnx_dtype_map["Float"]) return g.op( "Cast", diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index b1b2543ac..f9eb845e6 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -648,6 +648,29 @@ def _real_quantize(self, inputs): self._dequantize = True return outputs + def _get_block_sizes_list(self, shape): + """Convert block_sizes dict to list format based on tensor shape. 
+ + Args: + shape: The tensor shape to use for conversion (can be tuple or torch.Size) + + Returns: + List of block sizes for each dimension, or None if block_sizes is None + + Example: + block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1] + """ + if self.block_sizes is None: + return None + + block_sizes_list = [] + for dim in range(len(shape)): + # Check both positive and negative dimension indices + dim_negative = dim - len(shape) + block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) + block_sizes_list.append(block_size if block_size is not None else 1) + return block_sizes_list + def _fake_quantize(self, inputs): """Fake quantization.""" amax = None @@ -656,7 +679,7 @@ def _fake_quantize(self, inputs): self._validate_amax(amax) if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic": - # Block quantization, including dynamic and static block quantization + # Double scale Block quantization, including dynamic and static block quantization block_size = self.block_sizes.get(-1, None) or self.block_sizes.get( inputs.dim() - 1, None ) @@ -677,9 +700,14 @@ def _fake_quantize(self, inputs): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 + # Convert block_sizes dict to list format + # Use original input shape if available (before reshaping), otherwise use current shape + shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape) + block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes) outputs = scaled_e4m3( inputs, amax, + block_sizes_list, self._get_bias(inputs), E, M, @@ -928,9 +956,9 @@ def forward(self, inputs): and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): - # Tensor reshaping is required for static block quantization - # Tensor shapes are handled separately by the quantization kernels for dynamic block quantization + # Reshape is required if the logic isn’t handled in the simulation kernel self._setup_for_blockquant(inputs) + setattr(self, "_original_input_shape", inputs.shape) inputs = self._process_for_blockquant(inputs) outputs = inputs @@ -971,6 +999,8 @@ def forward(self, inputs): ): outputs = self._reset_to_original_shape(outputs) + if hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape") return outputs def _short_amax(self, fmt=".4f"): diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index 5f1ab5db1..e3f375c1e 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -114,9 +114,18 @@ def _quantized_sdpa(self, *args, **kwargs): key = self.k_bmm_quantizer(key) value = self.v_bmm_quantizer(value) - q_quantized_scale = self.q_bmm_quantizer._get_amax(query) - k_quantized_scale = self.k_bmm_quantizer._get_amax(key) - v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic: + q_quantized_scale = self.q_bmm_quantizer._get_amax(query) + k_quantized_scale = self.k_bmm_quantizer._get_amax(key) + v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + else: + assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type" + q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None + + # Get block sizes lists for each quantizer if needed + q_block_sizes = 
self.q_bmm_quantizer._get_block_sizes_list(query.shape) + k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) + v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # We don't need to calibrate the output of softmax return self.bmm2_output_quantizer( @@ -132,6 +141,9 @@ def _quantized_sdpa(self, *args, **kwargs): if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype") else "Half", self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True, + q_block_sizes, + k_block_sizes, + v_block_sizes, ) ) @@ -185,6 +197,9 @@ def forward( v_quantized_scale=None, high_precision_flag=None, disable_fp8_mha=True, + q_block_shape: list | None = None, + k_block_shape: list | None = None, + v_block_shape: list | None = None, ): """Forward method.""" ctx.save_for_backward(query, key, value, attn_mask) @@ -203,7 +218,7 @@ def forward( ) @staticmethod - @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b") + @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is") def symbolic( g: jit_utils.GraphContext, query: torch._C.Value, @@ -213,11 +228,14 @@ def symbolic( dropout_p: float = 0.0, is_causal: bool = False, scale: torch._C.Value | None = None, - q_quantized_scale: float = 1.0, - k_quantized_scale: float = 1.0, - v_quantized_scale: float = 1.0, + q_quantized_scale: float | None = 1.0, + k_quantized_scale: float | None = 1.0, + v_quantized_scale: float | None = 1.0, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, + q_block_shape: list | None = None, + k_block_shape: list | None = None, + v_block_shape: list | None = None, ): """Symbolic method.""" return export_fp8_mha( @@ -234,4 +252,7 @@ def symbolic( v_quantized_scale, high_precision_flag, disable_fp8_mha, + q_block_shape, + k_block_shape, + v_block_shape, ) diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 9bd73d909..f5ee14595 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -412,11 +412,12 @@ class ScaledE4M3Function(Function): """E4M3fy input with scale.""" @staticmethod - @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b") + @symbolic_helper.parse_args("v", "t", "t", "is", "i", "i", "s", "b") def symbolic( g, inputs, amax=None, + block_sizes=None, bias=None, E=4, # noqa: N803 M=3, # noqa: N803 @@ -426,7 +427,7 @@ def symbolic( """ONNX symbolic function.""" from .export_onnx import export_fp8 - return export_fp8(g, inputs, amax, trt_high_precision_dtype) + return export_fp8(g, inputs, amax, trt_high_precision_dtype, block_sizes) @staticmethod # Default values could cause errors from TorchDynamo during torch.export @@ -434,6 +435,7 @@ def forward( ctx, inputs, amax, + block_sizes, bias, E, # noqa: N803 M, # noqa: N803 From 831c32d708005f3a0ac33004d837e9346d351661 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 15 Sep 2025 23:27:15 +0000 Subject: [PATCH 02/12] Lint Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/config.py | 2 +- modelopt/torch/quantization/export_onnx.py | 13 ++++++---- .../nn/modules/tensor_quantizer.py | 12 ++++----- .../torch/quantization/plugins/diffusers.py | 26 +++++++++++++------ 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index ecfd75f4f..0f5eefa0f 100644 --- a/examples/diffusers/quantization/config.py +++ 
b/examples/diffusers/quantization/config.py @@ -36,7 +36,7 @@ "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, "*input_quantizer": {"num_bits": (4, 3), "axis": None}, "*output_quantizer": {"enable": False}, - "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}}, + "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}}, "*softmax_quantizer": { "num_bits": (4, 3), "axis": None, diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index e1fe73828..0b568e72b 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -234,6 +234,7 @@ def _fp8_quantize( ) return q_op + def _fp8_block_quantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, @@ -289,13 +290,14 @@ def _fp8_dequantize( out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] return out + def _fp8_block_dequantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, scales: torch.Value, trt_high_precision_dtype: str, otype: str | None = None, - block_sizes: list = [1,1,128,1] + block_sizes: list = [1, 1, 128, 1], ): """Helper Function for Dequantization.""" output_shape = sym_help._get_tensor_sizes(inputs) @@ -339,8 +341,7 @@ def export_fp8( ) return _fp8_block_dequantize( g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes - ) - + ) def scaled_dot_product_attention( @@ -498,7 +499,9 @@ def export_fp8_mha( v_input_dtype = value.type().scalarType() if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}: raise ValueError("The quantized MHA must have 16-bit inputs.") - query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape) + query_scaled = export_fp8( + g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape + ) query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"]) key_transposed_scaled = export_fp8( g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape @@ -531,7 +534,7 @@ def export_fp8_mha( if not disable_fp8_mha: # Softmax's output scale is hard coded to 1.0 - # We cannot do block quant for the softmax's output + # We cannot do block quant for the softmax's output attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"]) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index f9eb845e6..8c7ca7b46 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -650,19 +650,19 @@ def _real_quantize(self, inputs): def _get_block_sizes_list(self, shape): """Convert block_sizes dict to list format based on tensor shape. 
- + Args: shape: The tensor shape to use for conversion (can be tuple or torch.Size) - + Returns: List of block sizes for each dimension, or None if block_sizes is None - + Example: block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1] """ if self.block_sizes is None: return None - + block_sizes_list = [] for dim in range(len(shape)): # Check both positive and negative dimension indices @@ -670,7 +670,7 @@ def _get_block_sizes_list(self, shape): block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) block_sizes_list.append(block_size if block_size is not None else 1) return block_sizes_list - + def _fake_quantize(self, inputs): """Fake quantization.""" amax = None @@ -956,7 +956,7 @@ def forward(self, inputs): and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): - # Reshape is required if the logic isn’t handled in the simulation kernel + # Reshape is required if the logic isnt handled in the simulation kernel self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) inputs = self._process_for_blockquant(inputs) diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index e3f375c1e..7a5777b30 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -114,18 +114,26 @@ def _quantized_sdpa(self, *args, **kwargs): key = self.k_bmm_quantizer(key) value = self.v_bmm_quantizer(value) - if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic: + if ( + not self.q_bmm_quantizer._dynamic + and not self.k_bmm_quantizer._dynamic + and not self.v_bmm_quantizer._dynamic + ): q_quantized_scale = self.q_bmm_quantizer._get_amax(query) k_quantized_scale = self.k_bmm_quantizer._get_amax(key) v_quantized_scale = self.v_bmm_quantizer._get_amax(value) else: - assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type" + assert ( + self.q_bmm_quantizer._dynamic + and self.k_bmm_quantizer._dynamic + and self.v_bmm_quantizer._dynamic + ), "QKV QDQS must be in the same type" q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None - + # Get block sizes lists for each quantizer if needed - q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) - k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) - v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) + q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr] + k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr] + v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr] # We don't need to calibrate the output of softmax return self.bmm2_output_quantizer( @@ -142,7 +150,7 @@ def _quantized_sdpa(self, *args, **kwargs): else "Half", self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True, q_block_sizes, - k_block_sizes, + k_block_sizes, v_block_sizes, ) ) @@ -218,7 +226,9 @@ def forward( ) @staticmethod - @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is") + @symbolic_helper.parse_args( + "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is" + ) def symbolic( g: jit_utils.GraphContext, query: torch._C.Value, From 
d2c6e0f4afe54913df31ec9fd130ade77f072fa7 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 16 Sep 2025 04:20:07 +0000 Subject: [PATCH 03/12] Fixed test cases fail Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 2 +- modelopt/torch/quantization/tensor_quant.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 8c7ca7b46..5436d3131 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -707,12 +707,12 @@ def _fake_quantize(self, inputs): outputs = scaled_e4m3( inputs, amax, - block_sizes_list, self._get_bias(inputs), E, M, self._trt_high_precision_dtype, self._pass_through_bwd, + block_sizes_list, ) else: diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index f5ee14595..09369dd86 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -417,12 +417,12 @@ def symbolic( g, inputs, amax=None, - block_sizes=None, bias=None, E=4, # noqa: N803 M=3, # noqa: N803 trt_high_precision_dtype=None, pass_through_bwd=False, + block_sizes=None, ): """ONNX symbolic function.""" from .export_onnx import export_fp8 @@ -435,12 +435,12 @@ def forward( ctx, inputs, amax, - block_sizes, bias, E, # noqa: N803 M, # noqa: N803 trt_high_precision_dtype=None, pass_through_bwd=False, + block_sizes=None, ): """Forward method.""" if E != 4 or M != 3: From 94ec97dfce0a18591f6a91279cf1805efcce4e7e Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 17 Sep 2025 05:50:31 +0000 Subject: [PATCH 04/12] lint Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 2a1fd694d..9cfe764be 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -959,7 +959,7 @@ def forward(self, inputs): and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): - # Reshape is required if the logic isnt handled in the simulation kernel + # Reshape is required if the logic is not handled in the simulation kernel self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) inputs = self._process_for_blockquant(inputs) From 35f3da231f6792e40b1f374f82ef055e4d5ff528 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 17 Sep 2025 05:55:32 +0000 Subject: [PATCH 05/12] update comment Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 9cfe764be..882b24692 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -960,6 +960,7 @@ def forward(self, inputs): and self._fake_quant ): # Reshape is required if the logic is not handled in the simulation kernel + # Only MX format and NVFP4 reshape are currently supported by the kernel. 
self._setup_for_blockquant(inputs) setattr(self, "_original_input_shape", inputs.shape) inputs = self._process_for_blockquant(inputs) From 25be640b396a88e8fa9d6e190bf57552f0b37a33 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 17 Sep 2025 17:30:41 +0000 Subject: [PATCH 06/12] Fix the test case Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/tensor_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 09369dd86..bd25240fb 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -468,7 +468,7 @@ def forward( @staticmethod def backward(ctx, grad_outputs): """Implements straight through estimation with clipping.""" - return _fake_quant_backward_function(ctx, grad_outputs, num_args=7) + return _fake_quant_backward_function(ctx, grad_outputs, num_args=8) def _dynamic_block_quantize_forward( From fb33cac2199897d0fea00356475c3f8ce82a4fc6 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 8 Oct 2025 21:29:27 +0000 Subject: [PATCH 07/12] change the block size to 128 Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 0f5eefa0f..9a1b6e2ad 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -36,7 +36,7 @@ "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, "*input_quantizer": {"num_bits": (4, 3), "axis": None}, "*output_quantizer": {"enable": False}, - "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}}, + "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 128}}, "*softmax_quantizer": { "num_bits": (4, 3), "axis": None, From df6bf808c9494c7c76b83774a4de6e3c8a08ac3d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Oct 2025 21:13:11 +0000 Subject: [PATCH 08/12] Update the dim Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 7fbce4273..70cdc15c7 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -629,7 +629,7 @@ def _get_block_sizes_list(self, shape): List of block sizes for each dimension, or None if block_sizes is None Example: - block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1] + block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, -1] """ if self.block_sizes is None: return None @@ -639,7 +639,9 @@ def _get_block_sizes_list(self, shape): # Check both positive and negative dimension indices dim_negative = dim - len(shape) block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) - block_sizes_list.append(block_size if block_size is not None else 1) + # Use -1 for the last dimension if not specified, otherwise use 1 + default_value = -1 if dim == len(shape) - 1 else 1 + block_sizes_list.append(block_size if block_size is not None else default_value) return block_sizes_list def _fake_quantize(self, inputs): From 09c491ab4050f7da7e103bfeb1677be030714172 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Oct 2025 21:53:04 
+0000 Subject: [PATCH 09/12] Update the K Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/export_onnx.py | 2 +- .../quantization/nn/modules/tensor_quantizer.py | 12 ++++++++++-- modelopt/torch/quantization/plugins/diffusers.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index 14e31dae1..955cf60a7 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -302,7 +302,7 @@ def _fp8_block_dequantize( scales: torch.Value, trt_high_precision_dtype: str, otype: str | None = None, - block_sizes: list = [1, 1, 128, 1], + block_sizes: list = [1, 1, 128, -1], ): """Helper Function for Dequantization.""" output_shape = sym_help._get_tensor_sizes(inputs) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 70cdc15c7..d1478ac3a 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -619,17 +619,20 @@ def _real_quantize(self, inputs): self._dequantize = True return outputs - def _get_block_sizes_list(self, shape): + def _get_block_sizes_list(self, shape, transpose=False): """Convert block_sizes dict to list format based on tensor shape. Args: shape: The tensor shape to use for conversion (can be tuple or torch.Size) + transpose: If True, swap the last two dimensions' block sizes Returns: List of block sizes for each dimension, or None if block_sizes is None Example: - block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, -1] + block_sizes = {-2: 32} with shape [2, 24, 4608, 128]: + - transpose=False -> [1, 1, 32, -1] + - transpose=True -> [1, 1, -1, 32] """ if self.block_sizes is None: return None @@ -642,6 +645,11 @@ def _get_block_sizes_list(self, shape): # Use -1 for the last dimension if not specified, otherwise use 1 default_value = -1 if dim == len(shape) - 1 else 1 block_sizes_list.append(block_size if block_size is not None else default_value) + + # If transpose is True, swap the last two dimensions + if transpose and len(block_sizes_list) >= 2: + block_sizes_list[-2], block_sizes_list[-1] = block_sizes_list[-1], block_sizes_list[-2] + return block_sizes_list def _fake_quantize(self, inputs): diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index 4291394e1..d8bcb7c40 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -135,7 +135,7 @@ def _quantized_sdpa(self, *args, **kwargs): # Get block sizes lists for each quantizer if needed q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr] - k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr] + k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape, transpose=True) # type: ignore[union-attr] v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr] # We don't need to calibrate the output of softmax From b01dcaa31942800cd57589e13581271b42202379 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 5 Nov 2025 23:24:22 +0000 Subject: [PATCH 10/12] Update the GraphContext Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/export_onnx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index 955cf60a7..ded974b7a 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -241,7 +241,7 @@ def _fp8_quantize( def _fp8_block_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, trt_high_precision_dtype: str, block_sizes: list, @@ -270,7 +270,7 @@ def _fp8_block_quantize( def _fp8_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale_inv: float, trt_high_precision_dtype: str, @@ -297,7 +297,7 @@ def _fp8_dequantize( def _fp8_block_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scales: torch.Value, trt_high_precision_dtype: str, @@ -326,7 +326,7 @@ def _fp8_block_dequantize( def export_fp8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: float | None, trt_high_precision_dtype: str | None, From f84f6e5771d300e97a10a56ede1499bb0b1d8f20 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 5 Nov 2025 23:27:44 +0000 Subject: [PATCH 11/12] Update the args Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/export_onnx.py | 2 +- modelopt/torch/quantization/plugins/diffusers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index ded974b7a..42e42d907 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -441,7 +441,7 @@ def export_fp8_mha( attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, q_quantized_scale: float | None = 1.0, k_quantized_scale: float | None = 1.0, v_quantized_scale: float | None = 1.0, diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index d8bcb7c40..11519c4c9 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -240,7 +240,7 @@ def symbolic( attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, q_quantized_scale: float | None = 1.0, k_quantized_scale: float | None = 1.0, v_quantized_scale: float | None = 1.0, From 3cebde6388379635cc99fae7762d23fd6ace8d7e Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 5 Nov 2025 23:39:35 +0000 Subject: [PATCH 12/12] update Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/export_onnx.py | 12 ++++++------ modelopt/torch/quantization/plugins/diffusers.py | 12 ++++++------ modelopt/torch/quantization/tensor_quant.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index 42e42d907..d2c4460a7 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -442,14 +442,14 @@ def export_fp8_mha( dropout_p: float = 0.0, is_causal: bool = False, scale: "torch._C.Value | None" = None, - q_quantized_scale: float | None = 1.0, - k_quantized_scale: float | None = 1.0, - v_quantized_scale: float | None = 1.0, + q_quantized_scale: "float | None" = 1.0, + k_quantized_scale: "float | None" = 1.0, + 
v_quantized_scale: "float | None" = 1.0, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, - q_block_shape: list | None = None, - k_block_shape: list | None = None, - v_block_shape: list | None = None, + q_block_shape: "list | None" = None, + k_block_shape: "list | None" = None, + v_block_shape: "list | None" = None, ): r"""Export quantized fMHA to FP8 ONNX. diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index 11519c4c9..846f0f773 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -241,14 +241,14 @@ def symbolic( dropout_p: float = 0.0, is_causal: bool = False, scale: "torch._C.Value | None" = None, - q_quantized_scale: float | None = 1.0, - k_quantized_scale: float | None = 1.0, - v_quantized_scale: float | None = 1.0, + q_quantized_scale: "float | None" = 1.0, + k_quantized_scale: "float | None" = 1.0, + v_quantized_scale: "float | None" = 1.0, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, - q_block_shape: list | None = None, - k_block_shape: list | None = None, - v_block_shape: list | None = None, + q_block_shape: "list | None" = None, + k_block_shape: "list | None" = None, + v_block_shape: "list | None" = None, ): """Symbolic method.""" return export_fp8_mha( diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index c4a3c491b..44bf42184 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -406,7 +406,7 @@ class ScaledE4M3Function(Function): """E4M3fy input with scale.""" @staticmethod - @symbolic_helper.parse_args("v", "t", "t", "is", "i", "i", "s", "b") + @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b", "is") def symbolic( g, inputs,