Commit 071f167

Add Sage Attn ONNX & Fixed a bug in diffusers
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 76e8ce2 commit 071f167

File tree

6 files changed: +160 additions, -23 deletions

examples/diffusers/quantization/config.py
examples/diffusers/quantization/quantize.py
modelopt/torch/quantization/export_onnx.py
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
modelopt/torch/quantization/plugins/diffusers.py
modelopt/torch/quantization/tensor_quant.py

examples/diffusers/quantization/config.py

Lines changed: 15 additions & 0 deletions
@@ -31,6 +31,21 @@
     "algorithm": "max",
 }
 
+FP8_SAGE_DEFAULT_CONFIG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
+        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
+        "*output_quantizer": {"enable": False},
+        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
+        "*softmax_quantizer": {
+            "num_bits": (4, 3),
+            "axis": None,
+        },
+        "default": {"enable": False},
+    },
+    "algorithm": "max",
+}
+
 INT8_DEFAULT_CONFIG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": 8, "axis": 0},

examples/diffusers/quantization/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -939,7 +939,7 @@ def forward_loop(mod):
         backbone,
         model_config.model_type,
         quant_config.format,
-        quantize_mha=QuantizationConfig.quantize_mha,
+        quantize_mha=quant_config.quantize_mha,
     )
     logger.info("Quantization process completed successfully!")
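
The one-line fix above swaps a lookup on the class (`QuantizationConfig.quantize_mha`, which always returns the class-level default) for a lookup on the parsed instance (`quant_config.quantize_mha`, the value actually set for this run). A simplified, hypothetical illustration of the difference:

```python
from dataclasses import dataclass


@dataclass
class QuantizationConfig:
    quantize_mha: bool = False  # class-level default


quant_config = QuantizationConfig(quantize_mha=True)  # value chosen for this run

print(QuantizationConfig.quantize_mha)  # False -- the old lookup, always the default
print(quant_config.quantize_mha)        # True  -- the fixed lookup, the per-run value
```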

modelopt/torch/quantization/export_onnx.py

Lines changed: 79 additions & 10 deletions
@@ -234,6 +234,34 @@ def _fp8_quantize(
     )
     return q_op
 
+def _fp8_block_quantize(
+    g: torch.onnx._internal.jit_utils.GraphContext,
+    inputs: torch.Value,
+    trt_high_precision_dtype: str,
+    block_sizes: list,
+):
+    """Helper Function for Quantization."""
+    output_shape = sym_help._get_tensor_sizes(inputs)
+
+    # TRT StronglyType only supports FP16 QDQs
+    # custom ops, so cast the input if needed.
+    input_type = inputs.type().scalarType()
+    assert trt_high_precision_dtype in (input_type, "Float"), (
+        "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
+    )
+    if trt_high_precision_dtype != input_type:
+        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype])
+    quantized_output, scales_output = g.op(
+        "trt::TRT_DynamicQuantize",
+        inputs,
+        block_shape_i=block_sizes,
+        outputs=2,
+        scale_type_i=onnx_dtype_map["Float"],
+        output_type_i=onnx_dtype_map["Float8"],
+    )
+    quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
+    return quantized_output, scales_output
+
 
 def _fp8_dequantize(
     g: torch.onnx._internal.jit_utils.GraphContext,

@@ -261,21 +289,58 @@ def _fp8_dequantize(
         out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
     return out
 
+def _fp8_block_dequantize(
+    g: torch.onnx._internal.jit_utils.GraphContext,
+    inputs: torch.Value,
+    scales: torch.Value,
+    trt_high_precision_dtype: str,
+    otype: str | None = None,
+    block_sizes: list = [1, 1, 128, 1],
+):
+    """Helper Function for Dequantization."""
+    output_shape = sym_help._get_tensor_sizes(inputs)
+    assert trt_high_precision_dtype in (otype, "Float"), (
+        "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float."
+    )
+    out = g.op(
+        "trt::TRT_BlockDequantize",
+        inputs,
+        scales,
+        block_shape_i=block_sizes,
+    ).setType(
+        inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
+    )
+
+    # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
+    # custom ops, so cast the output if needed.
+    if trt_high_precision_dtype != otype:
+        out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
+    return out
+
 
 def export_fp8(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
-    amax: float,
+    amax: float | None,
     trt_high_precision_dtype: str | None,
+    block_sizes: list | None,
 ):
     """Export quantized model to FP8 ONNX."""
     scale = 1.0 if amax is None else 448.0 / float(amax)
     otype = inputs.type().scalarType()
     if trt_high_precision_dtype is None:
         trt_high_precision_dtype = otype
+    if not block_sizes:
+        q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
+        return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
+    else:
+        q_tensor, scales_output = _fp8_block_quantize(
+            g, inputs, trt_high_precision_dtype, block_sizes
+        )
+        return _fp8_block_dequantize(
+            g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
+        )
 
-    q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
-    return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
 
 
 def scaled_dot_product_attention(

@@ -360,11 +425,14 @@ def export_fp8_mha(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: torch._C.Value | None = None,
-    q_quantized_scale: float = 1.0,
-    k_quantized_scale: float = 1.0,
-    v_quantized_scale: float = 1.0,
+    q_quantized_scale: float | None = 1.0,
+    k_quantized_scale: float | None = 1.0,
+    v_quantized_scale: float | None = 1.0,
     high_precision_flag: str = "Half",
     disable_fp8_mha: bool = True,
+    q_block_shape: list | None = None,
+    k_block_shape: list | None = None,
+    v_block_shape: list | None = None,
 ):
     r"""Export quantized fMHA to FP8 ONNX.
 

@@ -430,10 +498,10 @@ def export_fp8_mha(
     v_input_dtype = value.type().scalarType()
     if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}:
         raise ValueError("The quantized MHA must have 16-bit inputs.")
-    query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag)
+    query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape)
     query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"])
     key_transposed_scaled = export_fp8(
-        g, key_transposed_scaled, k_quantized_scale, high_precision_flag
+        g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
     )
     key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"])
     mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)

@@ -463,7 +531,8 @@ def export_fp8_mha(
 
     if not disable_fp8_mha:
         # Softmax's output scale is hard coded to 1.0
-        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag)
+        # We cannot do block quant for the softmax's output
+        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
         attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"])
 
     if dropout_p != 0:

@@ -473,7 +542,7 @@ def export_fp8_mha(
             g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
         )
     if not disable_fp8_mha:
-        value = export_fp8(g, value, v_quantized_scale, high_precision_flag)
+        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
         value = g.op("Cast", value, to_i=onnx_dtype_map["Float"])
     return g.op(
         "Cast",

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 33 additions & 3 deletions
@@ -648,6 +648,29 @@ def _real_quantize(self, inputs):
             self._dequantize = True
         return outputs
 
+    def _get_block_sizes_list(self, shape):
+        """Convert block_sizes dict to list format based on tensor shape.
+
+        Args:
+            shape: The tensor shape to use for conversion (can be tuple or torch.Size)
+
+        Returns:
+            List of block sizes for each dimension, or None if block_sizes is None
+
+        Example:
+            block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
+        """
+        if self.block_sizes is None:
+            return None
+
+        block_sizes_list = []
+        for dim in range(len(shape)):
+            # Check both positive and negative dimension indices
+            dim_negative = dim - len(shape)
+            block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
+            block_sizes_list.append(block_size if block_size is not None else 1)
+        return block_sizes_list
+
     def _fake_quantize(self, inputs):
         """Fake quantization."""
         amax = None

@@ -656,7 +679,7 @@ def _fake_quantize(self, inputs):
             self._validate_amax(amax)
 
         if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic":
-            # Block quantization, including dynamic and static block quantization
+            # Double scale Block quantization, including dynamic and static block quantization
             block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
                 inputs.dim() - 1, None
             )

@@ -677,9 +700,14 @@ def _fake_quantize(self, inputs):
             # Float-point quantization, e.g., FP8
             E, M = self._num_bits  # noqa: N806
 
+            # Convert block_sizes dict to list format
+            # Use original input shape if available (before reshaping), otherwise use current shape
+            shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
+            block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
             outputs = scaled_e4m3(
                 inputs,
                 amax,
+                block_sizes_list,
                 self._get_bias(inputs),
                 E,
                 M,

@@ -928,9 +956,9 @@ def forward(self, inputs):
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
-            # Tensor reshaping is required for static block quantization
-            # Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
+            # Reshape is required if the logic isn't handled in the simulation kernel
            self._setup_for_blockquant(inputs)
+            setattr(self, "_original_input_shape", inputs.shape)
             inputs = self._process_for_blockquant(inputs)
 
         outputs = inputs

@@ -971,6 +999,8 @@ def forward(self, inputs):
         ):
             outputs = self._reset_to_original_shape(outputs)
 
+        if hasattr(self, "_original_input_shape"):
+            delattr(self, "_original_input_shape")
         return outputs
 
     def _short_amax(self, fmt=".4f"):
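
A standalone sketch of the conversion performed by `_get_block_sizes_list` (hypothetical free function for illustration; the real implementation is the method added above and reads `self.block_sizes`):

```python
def get_block_sizes_list(block_sizes: dict | None, shape) -> list | None:
    """Expand a {dim: block_size} dict into a per-dimension list for `shape`."""
    if block_sizes is None:
        return None
    result = []
    for dim in range(len(shape)):
        # Accept either positive or negative dimension keys.
        block = block_sizes.get(dim) or block_sizes.get(dim - len(shape))
        result.append(block if block is not None else 1)
    return result


print(get_block_sizes_list({-2: 32}, (2, 24, 4608, 128)))  # [1, 1, 32, 1]
```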

modelopt/torch/quantization/plugins/diffusers.py

Lines changed: 28 additions & 7 deletions
@@ -114,9 +114,18 @@ def _quantized_sdpa(self, *args, **kwargs):
         key = self.k_bmm_quantizer(key)
         value = self.v_bmm_quantizer(value)
 
-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic:
+            q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
+            k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
+            v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        else:
+            assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQs must all be of the same type"
+            q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
+
+        # Get block sizes lists for each quantizer if needed
+        q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)
+        k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)
+        v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)
 
         # We don't need to calibrate the output of softmax
         return self.bmm2_output_quantizer(

@@ -132,6 +141,9 @@ def _quantized_sdpa(self, *args, **kwargs):
                 if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype")
                 else "Half",
                 self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
+                q_block_sizes,
+                k_block_sizes,
+                v_block_sizes,
             )
         )
 

@@ -185,6 +197,9 @@ def forward(
         v_quantized_scale=None,
         high_precision_flag=None,
         disable_fp8_mha=True,
+        q_block_shape: list | None = None,
+        k_block_shape: list | None = None,
+        v_block_shape: list | None = None,
     ):
         """Forward method."""
         ctx.save_for_backward(query, key, value, attn_mask)

@@ -203,7 +218,7 @@ def forward(
         )
 
     @staticmethod
-    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
+    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is")
     def symbolic(
         g: jit_utils.GraphContext,
         query: torch._C.Value,

@@ -213,11 +228,14 @@ def symbolic(
         dropout_p: float = 0.0,
         is_causal: bool = False,
         scale: torch._C.Value | None = None,
-        q_quantized_scale: float = 1.0,
-        k_quantized_scale: float = 1.0,
-        v_quantized_scale: float = 1.0,
+        q_quantized_scale: float | None = 1.0,
+        k_quantized_scale: float | None = 1.0,
+        v_quantized_scale: float | None = 1.0,
         high_precision_flag: str = "Half",
         disable_fp8_mha: bool = True,
+        q_block_shape: list | None = None,
+        k_block_shape: list | None = None,
+        v_block_shape: list | None = None,
     ):
         """Symbolic method."""
         return export_fp8_mha(

@@ -234,4 +252,7 @@ def symbolic(
             v_quantized_scale,
             high_precision_flag,
             disable_fp8_mha,
+            q_block_shape,
+            k_block_shape,
+            v_block_shape,
         )
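
Read together, the changes to `_quantized_sdpa` amount to the following scale selection, sketched here as a hypothetical helper (not code from this commit): static Q/K/V quantizers export amax-derived per-tensor scales, while dynamic (Sage-style) quantizers export `None` scales plus per-tensor block-shape lists.

```python
def _mha_scales_and_blocks(q_quant, k_quant, v_quant, query, key, value):
    dynamic_flags = [q_quant._dynamic, k_quant._dynamic, v_quant._dynamic]
    if not any(dynamic_flags):
        # Static quantizers: per-tensor scales from calibrated amax.
        scales = (
            q_quant._get_amax(query),
            k_quant._get_amax(key),
            v_quant._get_amax(value),
        )
    else:
        # Dynamic quantizers compute scales at runtime, so only block shapes
        # are exported; mixing static and dynamic Q/K/V is not supported.
        assert all(dynamic_flags), "QKV quantizers must all be static or all dynamic"
        scales = (None, None, None)
    blocks = (
        q_quant._get_block_sizes_list(query.shape),
        k_quant._get_block_sizes_list(key.shape),
        v_quant._get_block_sizes_list(value.shape),
    )
    return scales, blocks
```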

modelopt/torch/quantization/tensor_quant.py

Lines changed: 4 additions & 2 deletions
@@ -412,11 +412,12 @@ class ScaledE4M3Function(Function):
     """E4M3fy input with scale."""
 
     @staticmethod
-    @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s", "b")
+    @symbolic_helper.parse_args("v", "t", "t", "is", "i", "i", "s", "b")
     def symbolic(
         g,
         inputs,
         amax=None,
+        block_sizes=None,
         bias=None,
         E=4,  # noqa: N803
         M=3,  # noqa: N803

@@ -426,14 +427,15 @@
         """ONNX symbolic function."""
         from .export_onnx import export_fp8
 
-        return export_fp8(g, inputs, amax, trt_high_precision_dtype)
+        return export_fp8(g, inputs, amax, trt_high_precision_dtype, block_sizes)
 
     @staticmethod
     # Default values could cause errors from TorchDynamo during torch.export
     def forward(
         ctx,
         inputs,
         amax,
+        block_sizes,
         bias,
         E,  # noqa: N803
         M,  # noqa: N803
