diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 0c151f8d..0f5eefa0 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -31,6 +31,21 @@ "algorithm": "max", } +FP8_SAGE_DEFAULT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, + "*input_quantizer": {"num_bits": (4, 3), "axis": None}, + "*output_quantizer": {"enable": False}, + "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}}, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + INT8_DEFAULT_CONFIG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index f94a4a1a..fadbeda6 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -939,7 +939,7 @@ def forward_loop(mod): backbone, model_config.model_type, quant_config.format, - quantize_mha=QuantizationConfig.quantize_mha, + quantize_mha=quant_config.quantize_mha, ) logger.info("Quantization process completed successfully!") diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index fe9bd927..a122865a 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -235,6 +235,35 @@ def _fp8_quantize( return q_op +def _fp8_block_quantize( + g: torch.onnx._internal.jit_utils.GraphContext, + inputs: torch.Value, + trt_high_precision_dtype: str, + block_sizes: list, +): + """Helper Function for Quantization.""" + output_shape = sym_help._get_tensor_sizes(inputs) + + # TRT StronglyType only supports FP16 QDQs + # custom ops, so cast the input if needed. + input_type = inputs.type().scalarType() + assert trt_high_precision_dtype in (input_type, "Float"), ( + "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." + ) + if trt_high_precision_dtype != input_type: + inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype]) + quantized_output, scales_output = g.op( + "trt::TRT_DynamicQuantize", + inputs, + block_shape_i=block_sizes, + outputs=2, + scale_type_i=onnx_dtype_map["Float"], + output_type_i=onnx_dtype_map["Float8"], + ) + quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)) + return quantized_output, scales_output + + def _fp8_dequantize( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, @@ -262,20 +291,57 @@ def _fp8_dequantize( return out +def _fp8_block_dequantize( + g: torch.onnx._internal.jit_utils.GraphContext, + inputs: torch.Value, + scales: torch.Value, + trt_high_precision_dtype: str, + otype: str | None = None, + block_sizes: list = [1, 1, 128, 1], +): + """Helper Function for Dequantization.""" + output_shape = sym_help._get_tensor_sizes(inputs) + assert trt_high_precision_dtype in (otype, "Float"), ( + "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." + ) + out = g.op( + "trt::TRT_BlockDequantize", + inputs, + scales, + block_shape_i=block_sizes, + ).setType( + inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape) + ) + + # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT + # custom ops, so cast the output if needed. 
+ if trt_high_precision_dtype != otype: + out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] + return out + + def export_fp8( g: torch.onnx._internal.jit_utils.GraphContext, inputs: torch.Value, - amax: float, + amax: float | None, trt_high_precision_dtype: str | None, + block_sizes: list | None, ): """Export quantized model to FP8 ONNX.""" scale = 1.0 if amax is None else 448.0 / float(amax) otype = inputs.type().scalarType() if trt_high_precision_dtype is None: trt_high_precision_dtype = otype - - q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) - return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) + if not block_sizes: + q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) + return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) + else: + q_tensor, scales_output = _fp8_block_quantize( + g, inputs, trt_high_precision_dtype, block_sizes + ) + return _fp8_block_dequantize( + g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes + ) def scaled_dot_product_attention( @@ -365,11 +431,14 @@ def export_fp8_mha( dropout_p: float = 0.0, is_causal: bool = False, scale: torch._C.Value | None = None, - q_quantized_scale: float = 1.0, - k_quantized_scale: float = 1.0, - v_quantized_scale: float = 1.0, + q_quantized_scale: float | None = 1.0, + k_quantized_scale: float | None = 1.0, + v_quantized_scale: float | None = 1.0, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, + q_block_shape: list | None = None, + k_block_shape: list | None = None, + v_block_shape: list | None = None, ): r"""Export quantized fMHA to FP8 ONNX. @@ -440,10 +509,12 @@ def export_fp8_mha( v_input_dtype = value.type().scalarType() if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}: raise ValueError("The quantized MHA must have 16-bit inputs.") - query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag) + query_scaled = export_fp8( + g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape + ) query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"]) key_transposed_scaled = export_fp8( - g, key_transposed_scaled, k_quantized_scale, high_precision_flag + g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape ) key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"]) mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) @@ -473,7 +544,8 @@ def export_fp8_mha( if not disable_fp8_mha: # Softmax's output scale is hard coded to 1.0 - attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag) + # We cannot do block quant for the softmax's output + attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None) attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"]) if dropout_p != 0: @@ -483,7 +555,7 @@ def export_fp8_mha( g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), ) if not disable_fp8_mha: - value = export_fp8(g, value, v_quantized_scale, high_precision_flag) + value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape) value = g.op("Cast", value, to_i=onnx_dtype_map["Float"]) return g.op( "Cast", diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 0635b7c9..882b2469 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ 
b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -646,6 +646,29 @@ def _real_quantize(self, inputs): self._dequantize = True return outputs + def _get_block_sizes_list(self, shape): + """Convert block_sizes dict to list format based on tensor shape. + + Args: + shape: The tensor shape to use for conversion (can be tuple or torch.Size) + + Returns: + List of block sizes for each dimension, or None if block_sizes is None + + Example: + block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1] + """ + if self.block_sizes is None: + return None + + block_sizes_list = [] + for dim in range(len(shape)): + # Check both positive and negative dimension indices + dim_negative = dim - len(shape) + block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None) + block_sizes_list.append(block_size if block_size is not None else 1) + return block_sizes_list + def _fake_quantize(self, inputs): """Fake quantization.""" amax = None @@ -654,7 +677,7 @@ def _fake_quantize(self, inputs): self._validate_amax(amax) if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic": - # Block quantization, including dynamic and static block quantization + # Double scale Block quantization, including dynamic and static block quantization block_size = self.block_sizes.get(-1, None) or self.block_sizes.get( inputs.dim() - 1, None ) @@ -675,6 +698,10 @@ def _fake_quantize(self, inputs): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 + # Convert block_sizes dict to list format + # Use original input shape if available (before reshaping), otherwise use current shape + shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape) + block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes) outputs = scaled_e4m3( inputs, amax, @@ -683,6 +710,7 @@ def _fake_quantize(self, inputs): M, self._trt_high_precision_dtype, self._pass_through_bwd, + block_sizes_list, ) else: @@ -931,9 +959,10 @@ def forward(self, inputs): and self.block_sizes.get("type", None) != "dynamic" and self._fake_quant ): - # Tensor reshaping is required for static block quantization - # Tensor shapes are handled separately by the quantization kernels for dynamic block quantization + # Reshape is required if the logic is not handled in the simulation kernel + # Only MX format and NVFP4 reshape are currently supported by the kernel. 
self._setup_for_blockquant(inputs) + setattr(self, "_original_input_shape", inputs.shape) inputs = self._process_for_blockquant(inputs) outputs = inputs @@ -974,6 +1003,8 @@ def forward(self, inputs): ): outputs = self._reset_to_original_shape(outputs) + if hasattr(self, "_original_input_shape"): + delattr(self, "_original_input_shape") return outputs def _short_amax(self, fmt=".4f"): diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index 5f1ab5db..7a5777b3 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -114,9 +114,26 @@ def _quantized_sdpa(self, *args, **kwargs): key = self.k_bmm_quantizer(key) value = self.v_bmm_quantizer(value) - q_quantized_scale = self.q_bmm_quantizer._get_amax(query) - k_quantized_scale = self.k_bmm_quantizer._get_amax(key) - v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + if ( + not self.q_bmm_quantizer._dynamic + and not self.k_bmm_quantizer._dynamic + and not self.v_bmm_quantizer._dynamic + ): + q_quantized_scale = self.q_bmm_quantizer._get_amax(query) + k_quantized_scale = self.k_bmm_quantizer._get_amax(key) + v_quantized_scale = self.v_bmm_quantizer._get_amax(value) + else: + assert ( + self.q_bmm_quantizer._dynamic + and self.k_bmm_quantizer._dynamic + and self.v_bmm_quantizer._dynamic + ), "QKV QDQS must be in the same type" + q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None + + # Get block sizes lists for each quantizer if needed + q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr] + k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr] + v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr] # We don't need to calibrate the output of softmax return self.bmm2_output_quantizer( @@ -132,6 +149,9 @@ def _quantized_sdpa(self, *args, **kwargs): if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype") else "Half", self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True, + q_block_sizes, + k_block_sizes, + v_block_sizes, ) ) @@ -185,6 +205,9 @@ def forward( v_quantized_scale=None, high_precision_flag=None, disable_fp8_mha=True, + q_block_shape: list | None = None, + k_block_shape: list | None = None, + v_block_shape: list | None = None, ): """Forward method.""" ctx.save_for_backward(query, key, value, attn_mask) @@ -203,7 +226,9 @@ def forward( ) @staticmethod - @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b") + @symbolic_helper.parse_args( + "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is" + ) def symbolic( g: jit_utils.GraphContext, query: torch._C.Value, @@ -213,11 +238,14 @@ def symbolic( dropout_p: float = 0.0, is_causal: bool = False, scale: torch._C.Value | None = None, - q_quantized_scale: float = 1.0, - k_quantized_scale: float = 1.0, - v_quantized_scale: float = 1.0, + q_quantized_scale: float | None = 1.0, + k_quantized_scale: float | None = 1.0, + v_quantized_scale: float | None = 1.0, high_precision_flag: str = "Half", disable_fp8_mha: bool = True, + q_block_shape: list | None = None, + k_block_shape: list | None = None, + v_block_shape: list | None = None, ): """Symbolic method.""" return export_fp8_mha( @@ -234,4 +262,7 @@ def symbolic( v_quantized_scale, high_precision_flag, disable_fp8_mha, + q_block_shape, + k_block_shape, + v_block_shape, ) diff --git 
a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py
index 9bd73d90..bd25240f 100644
--- a/modelopt/torch/quantization/tensor_quant.py
+++ b/modelopt/torch/quantization/tensor_quant.py
@@ -412,7 +412,7 @@ class ScaledE4M3Function(Function):
     """E4M3fy input with scale."""

     @staticmethod
-    @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b")
+    @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b", "is")
     def symbolic(
         g,
         inputs,
@@ -422,11 +422,12 @@ def symbolic(
         M=3,  # noqa: N803
         trt_high_precision_dtype=None,
         pass_through_bwd=False,
+        block_sizes=None,
     ):
         """ONNX symbolic function."""
         from .export_onnx import export_fp8

-        return export_fp8(g, inputs, amax, trt_high_precision_dtype)
+        return export_fp8(g, inputs, amax, trt_high_precision_dtype, block_sizes)

     @staticmethod
     # Default values could cause errors from TorchDynamo during torch.export
@@ -439,6 +440,7 @@ def forward(
         M,  # noqa: N803
         trt_high_precision_dtype=None,
         pass_through_bwd=False,
+        block_sizes=None,
     ):
         """Forward method."""
         if E != 4 or M != 3:
@@ -466,7 +468,7 @@ def forward(
     @staticmethod
     def backward(ctx, grad_outputs):
         """Implements straight through estimation with clipping."""
-        return _fake_quant_backward_function(ctx, grad_outputs, num_args=7)
+        return _fake_quant_backward_function(ctx, grad_outputs, num_args=8)


 def _dynamic_block_quantize_forward(
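
For reference, a minimal standalone sketch of the dict-to-list conversion performed by the new TensorQuantizer._get_block_sizes_list helper, reproducing the example from its docstring; the function name block_sizes_dict_to_list and the test shapes below are illustrative only, not part of the patch.

def block_sizes_dict_to_list(block_sizes: dict, shape) -> list:
    """Expand a {dim: block_size} dict into one block size per tensor dimension."""
    block_sizes_list = []
    for dim in range(len(shape)):
        # Accept either the positive or the negative index for the same dimension.
        dim_negative = dim - len(shape)
        block_size = block_sizes.get(dim) or block_sizes.get(dim_negative)
        block_sizes_list.append(block_size if block_size is not None else 1)
    return block_sizes_list

# {-2: 32} over a [2, 24, 4608, 128] activation expands to [1, 1, 32, 1], the
# per-dimension block shape that export_fp8 forwards to the TRT ops as block_shape_i.
assert block_sizes_dict_to_list({-2: 32}, (2, 24, 4608, 128)) == [1, 1, 32, 1]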
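
A small worked example (not from the patch) of the per-tensor FP8 scale convention that export_fp8 keeps using on the non-block path; 448 is the largest magnitude representable in E4M3.

# export_fp8 computes scale = 448.0 / amax and passes 1.0 / scale to the Q/DQ
# helpers, i.e. amax / 448, the step that maps the calibrated range onto E4M3.
amax = 2.0
scale = 448.0 / amax      # 224.0
qdq_scale = 1.0 / scale   # 2.0 / 448.0 ~= 0.00446
assert abs(qdq_scale * 448.0 - amax) < 1e-12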
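
Finally, a hedged usage sketch for the new FP8_SAGE_DEFAULT_CONFIG: the function below is illustrative (the example's quantize.py drives this through its own CLI flow) and assumes the standard mtq.quantize post-training quantization entry point; backbone and calibration_batches are placeholders supplied by the caller.

import modelopt.torch.quantization as mtq

from config import FP8_SAGE_DEFAULT_CONFIG  # examples/diffusers/quantization/config.py


def quantize_backbone_with_sage_fp8(backbone, calibration_batches):
    """Apply the SAGE-style FP8 config (dynamic block-quantized QKV BMMs) to a backbone."""

    def forward_loop(model):
        # A few calibration passes let the max-calibrated (static) quantizers in the
        # config record their amax values; the dynamic QKV block quantizers need no
        # calibration. The **batch call convention is a placeholder assumption.
        for batch in calibration_batches:
            model(**batch)

    return mtq.quantize(backbone, FP8_SAGE_DEFAULT_CONFIG, forward_loop)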