15 changes: 15 additions & 0 deletions examples/diffusers/quantization/config.py
@@ -31,6 +31,21 @@
"algorithm": "max",
}

FP8_SAGE_DEFAULT_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
"*input_quantizer": {"num_bits": (4, 3), "axis": None},
"*output_quantizer": {"enable": False},
"*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
"*softmax_quantizer": {
"num_bits": (4, 3),
"axis": None,
},
"default": {"enable": False},
},
"algorithm": "max",
}

INT8_DEFAULT_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": 8, "axis": 0},
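A minimal sketch of how a config dict such as FP8_SAGE_DEFAULT_CONFIG can be applied, assuming modelopt's standard mtq.quantize entry point and a toy module in place of the diffusion backbone (quantize.py normally selects the config from its own CLI settings):

```python
# Sketch only: the import path and toy model are assumptions for illustration.
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

from config import FP8_SAGE_DEFAULT_CONFIG  # examples/diffusers/quantization/config.py

toy_backbone = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

def forward_loop(model):
    # Calibration loop: run a few representative batches so the "max" algorithm
    # can collect amax statistics for the enabled quantizers.
    for _ in range(8):
        model(torch.randn(2, 64))

quantized = mtq.quantize(toy_backbone, FP8_SAGE_DEFAULT_CONFIG, forward_loop)
```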
2 changes: 1 addition & 1 deletion examples/diffusers/quantization/quantize.py
@@ -939,7 +939,7 @@ def forward_loop(mod):
backbone,
model_config.model_type,
quant_config.format,
quantize_mha=QuantizationConfig.quantize_mha,
quantize_mha=quant_config.quantize_mha,
)
logger.info("Quantization process completed successfully!")

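The one-line fix above reads quantize_mha from the quant_config instance instead of the QuantizationConfig class, so the user-selected value is honored rather than the class default. A toy reproduction of the difference (the dataclass below is an assumption, not the project's actual definition):

```python
# Hypothetical stand-in for the example's QuantizationConfig, just to show the
# class-attribute vs. instance-attribute lookup that the fix addresses.
from dataclasses import dataclass

@dataclass
class QuantizationConfig:
    quantize_mha: bool = False

quant_config = QuantizationConfig(quantize_mha=True)

print(QuantizationConfig.quantize_mha)  # False: class-level default (the old, buggy read)
print(quant_config.quantize_mha)        # True: per-run instance value (the fixed read)
```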
94 changes: 83 additions & 11 deletions modelopt/torch/quantization/export_onnx.py
@@ -235,6 +235,35 @@ def _fp8_quantize(
return q_op


def _fp8_block_quantize(
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
trt_high_precision_dtype: str,
block_sizes: list,
):
"""Helper Function for Quantization."""
output_shape = sym_help._get_tensor_sizes(inputs)

# TRT StronglyType only supports FP16 QDQs
# custom ops, so cast the input if needed.
input_type = inputs.type().scalarType()
assert trt_high_precision_dtype in (input_type, "Float"), (
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
)
if trt_high_precision_dtype != input_type:
inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype])
quantized_output, scales_output = g.op(
"trt::TRT_DynamicQuantize",
inputs,
block_shape_i=block_sizes,
outputs=2,
scale_type_i=onnx_dtype_map["Float"],
output_type_i=onnx_dtype_map["Float8"],
)
quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
return quantized_output, scales_output


def _fp8_dequantize(
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
@@ -262,20 +291,57 @@ def _fp8_dequantize(
return out


def _fp8_block_dequantize(
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
scales: torch.Value,
trt_high_precision_dtype: str,
otype: str | None = None,
block_sizes: list = [1, 1, 128, 1],
):
"""Helper Function for Dequantization."""
output_shape = sym_help._get_tensor_sizes(inputs)
assert trt_high_precision_dtype in (otype, "Float"), (
"TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
)
out = g.op(
"trt::TRT_BlockDequantize",
inputs,
scales,
block_shape_i=block_sizes,
).setType(
inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
)

# DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
# custom ops, so cast the output if needed.
if trt_high_precision_dtype != otype:
out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index]
return out


def export_fp8(
g: torch.onnx._internal.jit_utils.GraphContext,
inputs: torch.Value,
amax: float,
amax: float | None,
trt_high_precision_dtype: str | None,
block_sizes: list | None,
):
"""Export quantized model to FP8 ONNX."""
scale = 1.0 if amax is None else 448.0 / float(amax)
otype = inputs.type().scalarType()
if trt_high_precision_dtype is None:
trt_high_precision_dtype = otype

q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
if not block_sizes:
q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
else:
q_tensor, scales_output = _fp8_block_quantize(
g, inputs, trt_high_precision_dtype, block_sizes
)
return _fp8_block_dequantize(
g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
)


def scaled_dot_product_attention(
@@ -360,11 +426,14 @@ def export_fp8_mha(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: torch._C.Value | None = None,
q_quantized_scale: float = 1.0,
k_quantized_scale: float = 1.0,
v_quantized_scale: float = 1.0,
q_quantized_scale: float | None = 1.0,
k_quantized_scale: float | None = 1.0,
v_quantized_scale: float | None = 1.0,
high_precision_flag: str = "Half",
disable_fp8_mha: bool = True,
q_block_shape: list | None = None,
k_block_shape: list | None = None,
v_block_shape: list | None = None,
):
r"""Export quantized fMHA to FP8 ONNX.

@@ -430,10 +499,12 @@
v_input_dtype = value.type().scalarType()
if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}:
raise ValueError("The quantized MHA must have 16-bit inputs.")
query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag)
query_scaled = export_fp8(
g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
)
query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"])
key_transposed_scaled = export_fp8(
g, key_transposed_scaled, k_quantized_scale, high_precision_flag
g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
)
key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"])
mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
@@ -463,7 +534,8 @@ def export_fp8_mha(

if not disable_fp8_mha:
# Softmax's output scale is hard coded to 1.0
attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag)
# We cannot do block quant for the softmax's output
attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"])

if dropout_p != 0:
@@ -473,7 +545,7 @@
g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
)
if not disable_fp8_mha:
value = export_fp8(g, value, v_quantized_scale, high_precision_flag)
value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
value = g.op("Cast", value, to_i=onnx_dtype_map["Float"])
return g.op(
"Cast",
36 changes: 33 additions & 3 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -648,6 +648,29 @@ def _real_quantize(self, inputs):
self._dequantize = True
return outputs

def _get_block_sizes_list(self, shape):
"""Convert block_sizes dict to list format based on tensor shape.

Args:
shape: The tensor shape to use for conversion (can be tuple or torch.Size)

Returns:
List of block sizes for each dimension, or None if block_sizes is None

Example:
block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
"""
if self.block_sizes is None:
return None

block_sizes_list = []
for dim in range(len(shape)):
# Check both positive and negative dimension indices
dim_negative = dim - len(shape)
block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
block_sizes_list.append(block_size if block_size is not None else 1)
return block_sizes_list

def _fake_quantize(self, inputs):
"""Fake quantization."""
amax = None
@@ -656,7 +679,7 @@ def _fake_quantize(self, inputs):
self._validate_amax(amax)

if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic":
# Block quantization, including dynamic and static block quantization
# Double scale Block quantization, including dynamic and static block quantization
block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
inputs.dim() - 1, None
)
@@ -677,6 +700,10 @@
# Float-point quantization, e.g., FP8
E, M = self._num_bits # noqa: N806

# Convert block_sizes dict to list format
# Use original input shape if available (before reshaping), otherwise use current shape
shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
outputs = scaled_e4m3(
inputs,
amax,
Expand All @@ -685,6 +712,7 @@ def _fake_quantize(self, inputs):
M,
self._trt_high_precision_dtype,
self._pass_through_bwd,
block_sizes_list,
)

else:
@@ -928,9 +956,9 @@ def forward(self, inputs):
and self.block_sizes.get("type", None) != "dynamic"
and self._fake_quant
):
# Tensor reshaping is required for static block quantization
# Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
# Reshape is required if the logic isnt handled in the simulation kernel
Review comment (Contributor):
nit: isn't

How do we check when the reshape is or isn't handled in the simulation kernel?

Reply (Contributor Author):
The kernels are here: https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/modelopt/torch/quantization/src

The kernels only support the MX format and reshaped NVFP4; other formats require a Torch reshape. I think the previous comment made a false statement, so I just added one more comment.

self._setup_for_blockquant(inputs)
setattr(self, "_original_input_shape", inputs.shape)
inputs = self._process_for_blockquant(inputs)

outputs = inputs
@@ -971,6 +999,8 @@
):
outputs = self._reset_to_original_shape(outputs)

if hasattr(self, "_original_input_shape"):
delattr(self, "_original_input_shape")
return outputs

def _short_amax(self, fmt=".4f"):
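As a standalone illustration of the dict-to-list conversion that the new _get_block_sizes_list helper performs, the sketch below mirrors its loop and the docstring example; it is illustrative code, not the method itself:

```python
# Every dimension defaults to a block size of 1 unless the block_sizes dict
# names it by positive or negative index (mirrors _get_block_sizes_list).
def block_sizes_to_list(block_sizes, shape):
    if block_sizes is None:
        return None
    sizes = []
    for dim in range(len(shape)):
        negative_dim = dim - len(shape)
        size = block_sizes.get(dim) or block_sizes.get(negative_dim)
        sizes.append(size if size is not None else 1)
    return sizes

# Docstring example: block_sizes = {-2: 32} with shape [2, 24, 4608, 128]
assert block_sizes_to_list({-2: 32}, (2, 24, 4608, 128)) == [1, 1, 32, 1]
```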
45 changes: 38 additions & 7 deletions modelopt/torch/quantization/plugins/diffusers.py
@@ -114,9 +114,26 @@ def _quantized_sdpa(self, *args, **kwargs):
key = self.k_bmm_quantizer(key)
value = self.v_bmm_quantizer(value)

q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
if (
not self.q_bmm_quantizer._dynamic
and not self.k_bmm_quantizer._dynamic
and not self.v_bmm_quantizer._dynamic
):
q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
else:
assert (
self.q_bmm_quantizer._dynamic
and self.k_bmm_quantizer._dynamic
and self.v_bmm_quantizer._dynamic
), "QKV QDQS must be in the same type"
q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None

# Get block sizes lists for each quantizer if needed
q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape) # type: ignore[union-attr]
k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape) # type: ignore[union-attr]
v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape) # type: ignore[union-attr]

# We don't need to calibrate the output of softmax
return self.bmm2_output_quantizer(
Expand All @@ -132,6 +149,9 @@ def _quantized_sdpa(self, *args, **kwargs):
if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype")
else "Half",
self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
q_block_sizes,
k_block_sizes,
v_block_sizes,
)
)

@@ -185,6 +205,9 @@ def forward(
v_quantized_scale=None,
high_precision_flag=None,
disable_fp8_mha=True,
q_block_shape: list | None = None,
k_block_shape: list | None = None,
v_block_shape: list | None = None,
):
"""Forward method."""
ctx.save_for_backward(query, key, value, attn_mask)
@@ -203,7 +226,9 @@
)

@staticmethod
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
@symbolic_helper.parse_args(
"v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
Review comment (Contributor):
Why was this pattern changed?

Reply (Contributor Author):
You can check parse_args:
https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py#L301

Each value declares the type of the corresponding input, which helps Torch trace the graph more effectively. Since each block shape is a list of integers, three "is" entries are added (see the short sketch after this file's diff).

)
def symbolic(
g: jit_utils.GraphContext,
query: torch._C.Value,
@@ -213,11 +238,14 @@ def symbolic(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: torch._C.Value | None = None,
q_quantized_scale: float = 1.0,
k_quantized_scale: float = 1.0,
v_quantized_scale: float = 1.0,
q_quantized_scale: float | None = 1.0,
k_quantized_scale: float | None = 1.0,
v_quantized_scale: float | None = 1.0,
high_precision_flag: str = "Half",
disable_fp8_mha: bool = True,
q_block_shape: list | None = None,
k_block_shape: list | None = None,
v_block_shape: list | None = None,
):
"""Symbolic method."""
return export_fp8_mha(
@@ -234,4 +262,7 @@ def symbolic(
v_quantized_scale,
high_precision_flag,
disable_fp8_mha,
q_block_shape,
k_block_shape,
v_block_shape,
)
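Following up on the parse_args discussion above, a small sketch of how the type codes are consumed, assuming the torch.onnx.symbolic_helper conventions referenced in the reply ("v" = graph value, "f" = float, "b" = bool, "t" = tensor/number, "s" = string, "is" = list of ints); the decorated function here is a toy, not part of this PR:

```python
# Toy symbolic function: the decorator declares one graph value ("v") and one
# int-list ("is") argument, matching the block-shape parameters added above.
from torch.onnx import symbolic_helper

@symbolic_helper.parse_args("v", "is")
def toy_symbolic(g, x, block_shape):
    # After parsing, block_shape arrives as a plain Python list of ints
    # (e.g. [1, 1, 32, 1]) that can be attached as an op attribute like block_shape_i.
    return g.op("Identity", x)
```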
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/tensor_quant.py
@@ -412,7 +412,7 @@ class ScaledE4M3Function(Function):
"""E4M3fy input with scale."""

@staticmethod
@symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b")
@symbolic_helper.parse_args("v", "t", "t", "is", "i", "i", "s", "b")
def symbolic(
g,
inputs,
@@ -422,11 +422,12 @@ def symbolic(
M=3, # noqa: N803
trt_high_precision_dtype=None,
pass_through_bwd=False,
block_sizes=None,
):
"""ONNX symbolic function."""
from .export_onnx import export_fp8

return export_fp8(g, inputs, amax, trt_high_precision_dtype)
return export_fp8(g, inputs, amax, trt_high_precision_dtype, block_sizes)

@staticmethod
# Default values could cause errors from TorchDynamo during torch.export
@@ -439,6 +440,7 @@ def forward(
M, # noqa: N803
trt_high_precision_dtype=None,
pass_through_bwd=False,
block_sizes=None,
):
"""Forward method."""
if E != 4 or M != 3: