@@ -234,6 +234,34 @@ def _fp8_quantize(
    )
    return q_op

+def _fp8_block_quantize(
+    g: torch.onnx._internal.jit_utils.GraphContext,
+    inputs: torch.Value,
+    trt_high_precision_dtype: str,
+    block_sizes: list,
+):
+    """Helper function for block quantization."""
+    output_shape = sym_help._get_tensor_sizes(inputs)
+
+    # TRT StronglyType only supports FP16/BF16 QDQs for these
+    # custom ops, so cast the input if needed.
+    input_type = inputs.type().scalarType()
+    assert trt_high_precision_dtype in (input_type, "Float"), (
+        "TRT StronglyType requires both weights and amax to be in BF16/FP16, or the QDQ in Float."
+    )
+    if trt_high_precision_dtype != input_type:
+        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype])
+    quantized_output, scales_output = g.op(
+        "trt::TRT_DynamicQuantize",
+        inputs,
+        block_shape_i=block_sizes,
+        outputs=2,
+        scale_type_i=onnx_dtype_map["Float"],
+        output_type_i=onnx_dtype_map["Float8"],
+    )
+    quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
+    return quantized_output, scales_output
+

def _fp8_dequantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
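For readers unfamiliar with the semantics behind `trt::TRT_DynamicQuantize`, here is a rough, self-contained reference of what one-scale-per-block FP8 quantization computes. This is only an illustrative sketch under the assumption that `block_shape_i` gives the number of consecutive elements sharing a scale along each dimension and that scales are derived from the per-block amax with the usual e4m3 bound of 448; it is not the TensorRT kernel.

```python
import torch  # requires a PyTorch build with float8 dtypes


def block_quantize_reference(x: torch.Tensor, block_sizes: list[int]):
    """Return (q, scales): FP8 e4m3 values with one scale per block of `block_sizes` elements."""
    assert x.dim() == len(block_sizes)
    # Split every dimension into (num_blocks, block_size) so the per-block amax
    # can be reduced over the inner axes.
    split_shape: list[int] = []
    for dim, bs in zip(x.shape, block_sizes):
        assert dim % bs == 0, "each dimension must be divisible by its block size"
        split_shape += [dim // bs, bs]
    xb = x.reshape(split_shape)
    inner_axes = tuple(range(1, xb.dim(), 2))  # the block-size axes
    amax = xb.abs().amax(dim=inner_axes, keepdim=True).clamp(min=1e-12)
    scales = amax / 448.0  # 448 = max magnitude representable in FP8 e4m3
    q = (xb / scales).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scales


# Example: one scale for every 128 consecutive elements along dim 2.
x = torch.randn(2, 8, 256, 64)
q, s = block_quantize_reference(x, [1, 1, 128, 1])
```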
@@ -261,21 +289,58 @@ def _fp8_dequantize(
        out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
    return out

+def _fp8_block_dequantize(
+    g: torch.onnx._internal.jit_utils.GraphContext,
+    inputs: torch.Value,
+    scales: torch.Value,
+    trt_high_precision_dtype: str,
+    otype: str | None = None,
+    block_sizes: list = [1, 1, 128, 1],
+):
+    """Helper function for block dequantization."""
+    output_shape = sym_help._get_tensor_sizes(inputs)
+    assert trt_high_precision_dtype in (otype, "Float"), (
+        "TRT StronglyType requires both weights and amax to be in BF16/FP16, or the QDQ in Float."
+    )
+    out = g.op(
+        "trt::TRT_BlockDequantize",
+        inputs,
+        scales,
+        block_shape_i=block_sizes,
+    ).setType(
+        inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
+    )
+
+    # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
+    # custom ops, so cast the output if needed.
+    if trt_high_precision_dtype != otype:
+        out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
+    return out
+

def export_fp8(
    g: torch.onnx._internal.jit_utils.GraphContext,
    inputs: torch.Value,
-    amax: float,
+    amax: float | None,
    trt_high_precision_dtype: str | None,
+    block_sizes: list | None,
):
    """Export quantized model to FP8 ONNX."""
    scale = 1.0 if amax is None else 448.0 / float(amax)
    otype = inputs.type().scalarType()
    if trt_high_precision_dtype is None:
        trt_high_precision_dtype = otype
+    if not block_sizes:
+        q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
+        return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
+    else:
+        q_tensor, scales_output = _fp8_block_quantize(
+            g, inputs, trt_high_precision_dtype, block_sizes
+        )
+        return _fp8_block_dequantize(
+            g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
+        )

-    q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
-    return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)


def scaled_dot_product_attention(
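The matching reference for `trt::TRT_BlockDequantize`, as used by the new block path of `export_fp8`, broadcasts each per-block scale back over its block and multiplies it in. Again a sketch of the assumed semantics, not the TRT implementation; note that on the block path the amax-derived static scale is unused because `trt::TRT_DynamicQuantize` computes scales on the fly.

```python
import torch


def block_dequantize_reference(
    q: torch.Tensor, scales: torch.Tensor, block_sizes: list[int]
) -> torch.Tensor:
    """Invert the reference block quantization: expand each scale over its block and multiply."""
    split_shape: list[int] = []
    for dim, bs in zip(q.shape, block_sizes):
        split_shape += [dim // bs, bs]
    qb = q.reshape(split_shape).to(scales.dtype)
    # `scales` keeps singleton axes in the block positions, so it broadcasts over each block.
    return (qb * scales).reshape(q.shape)
```

Chained after the quantize reference above, `block_dequantize_reference(q, s, [1, 1, 128, 1])` should reproduce `x` up to e4m3 rounding error.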
@@ -360,11 +425,14 @@ def export_fp8_mha(
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: torch._C.Value | None = None,
-    q_quantized_scale: float = 1.0,
-    k_quantized_scale: float = 1.0,
-    v_quantized_scale: float = 1.0,
+    q_quantized_scale: float | None = 1.0,
+    k_quantized_scale: float | None = 1.0,
+    v_quantized_scale: float | None = 1.0,
    high_precision_flag: str = "Half",
    disable_fp8_mha: bool = True,
+    q_block_shape: list | None = None,
+    k_block_shape: list | None = None,
+    v_block_shape: list | None = None,
):
    r"""Export quantized fMHA to FP8 ONNX.

@@ -430,10 +498,10 @@ def export_fp8_mha(
        v_input_dtype = value.type().scalarType()
        if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}:
            raise ValueError("The quantized MHA must have 16-bit inputs.")
-        query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag)
+        query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape)
        query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"])
        key_transposed_scaled = export_fp8(
-            g, key_transposed_scaled, k_quantized_scale, high_precision_flag
+            g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
        )
        key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"])
    mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
@@ -463,7 +531,8 @@ def export_fp8_mha(

    if not disable_fp8_mha:
        # Softmax's output scale is hard coded to 1.0
-        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag)
+        # Block quantization is not applied to the softmax output
+        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
        attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"])

    if dropout_p != 0:
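A quick sanity check on why a hard-coded scale of 1.0 (and no block quantization) is a reasonable choice for the softmax output: its values already lie in [0, 1], well inside FP8 e4m3's representable range, so casting directly loses only rounding precision. This is an illustrative aside, not part of the change.

```python
import torch

attn = torch.softmax(torch.randn(4, 16), dim=-1)   # values in [0, 1]
roundtrip = attn.to(torch.float8_e4m3fn).float()   # effective scale of 1.0, no overflow possible
print((roundtrip - attn).abs().max())              # only small quantization error
```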
@@ -473,7 +542,7 @@ def export_fp8_mha(
            g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
        )
    if not disable_fp8_mha:
-        value = export_fp8(g, value, v_quantized_scale, high_precision_flag)
+        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
        value = g.op("Cast", value, to_i=onnx_dtype_map["Float"])
    return g.op(
        "Cast",