@@ -234,6 +234,34 @@ def _fp8_quantize(
     )
     return q_op

+def _fp8_block_quantize(
+    g: torch.onnx._internal.jit_utils.GraphContext,
+    inputs: torch.Value,
+    trt_high_precision_dtype: str,
+    block_sizes: list,
+):
+    """Helper Function for Quantization."""
+    output_shape = sym_help._get_tensor_sizes(inputs)
+
+    # TRT StronglyType only supports FP16 QDQ custom ops,
+    # so cast the input if needed.
+    input_type = inputs.type().scalarType()
+    assert trt_high_precision_dtype in (input_type, "Float"), (
+        "TRT StronglyType requires both weights and amax to be in BF16/FP16, or the QDQ in Float."
+    )
+    if trt_high_precision_dtype != input_type:
+        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype])
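+    # TRT_DynamicQuantize returns two outputs: the block-quantized tensor and its per-block scales.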
+    quantized_output, scales_output = g.op(
+        "trt::TRT_DynamicQuantize",
+        inputs,
+        block_shape_i=block_sizes,
+        outputs=2,
+        scale_type_i=onnx_dtype_map["Float"],
+        output_type_i=onnx_dtype_map["Float8"],
+    )
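+    # The FP8 output is typed as uint8 here, which appears to serve as a byte-level
+    # placeholder since the graph has no FP8 scalar type.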
+    quantized_output.setType(inputs.type().with_dtype(torch.uint8).with_sizes(output_shape))
+    return quantized_output, scales_output
+

 def _fp8_dequantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
@@ -261,21 +289,58 @@ def _fp8_dequantize(
         out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
     return out

+def _fp8_block_dequantize(
+    g: torch.onnx._internal.jit_utils.GraphContext,
+    inputs: torch.Value,
+    scales: torch.Value,
+    trt_high_precision_dtype: str,
+    otype: str | None = None,
+    block_sizes: list = [1, 1, 128, 1],
+):
+    """Helper Function for Dequantization."""
+    output_shape = sym_help._get_tensor_sizes(inputs)
+    assert trt_high_precision_dtype in (otype, "Float"), (
+        "TRT StronglyType requires both weights and amax to be in BF16/FP16, or the QDQ in Float."
+    )
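+    # TRT_BlockDequantize reconstructs the high-precision tensor from the FP8 data
+    # and the per-block scales produced at quantization time.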
+    out = g.op(
+        "trt::TRT_BlockDequantize",
+        inputs,
+        scales,
+        block_shape_i=block_sizes,
+    ).setType(
+        inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape)
+    )
+
+    # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT
+    # custom ops, so cast the output if needed.
+    if trt_high_precision_dtype != otype:
+        out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
+    return out
+

 def export_fp8(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
-    amax: float,
+    amax: float | None,
     trt_high_precision_dtype: str | None,
+    block_sizes: list | None,
 ):
     """Export quantized model to FP8 ONNX."""
     scale = 1.0 if amax is None else 448.0 / float(amax)
     otype = inputs.type().scalarType()
     if trt_high_precision_dtype is None:
         trt_high_precision_dtype = otype
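+    # Fall back to per-tensor QDQ when no block sizes are given; otherwise emit block-quantized QDQ.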
+    if not block_sizes:
+        q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
+        return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)
+    else:
+        q_tensor, scales_output = _fp8_block_quantize(
+            g, inputs, trt_high_precision_dtype, block_sizes
+        )
+        return _fp8_block_dequantize(
+            g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
+        )

-    q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype)
-    return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype)


 def scaled_dot_product_attention(
@@ -360,11 +425,14 @@ def export_fp8_mha(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: torch._C.Value | None = None,
-    q_quantized_scale: float = 1.0,
-    k_quantized_scale: float = 1.0,
-    v_quantized_scale: float = 1.0,
+    q_quantized_scale: float | None = 1.0,
+    k_quantized_scale: float | None = 1.0,
+    v_quantized_scale: float | None = 1.0,
     high_precision_flag: str = "Half",
     disable_fp8_mha: bool = True,
+    q_block_shape: list | None = None,
+    k_block_shape: list | None = None,
+    v_block_shape: list | None = None,
 ):
     r"""Export quantized fMHA to FP8 ONNX.

@@ -430,10 +498,10 @@ def export_fp8_mha(
     v_input_dtype = value.type().scalarType()
     if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}:
         raise ValueError("The quantized MHA must have 16-bit inputs.")
-    query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag)
+    query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape)
     query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"])
     key_transposed_scaled = export_fp8(
-        g, key_transposed_scaled, k_quantized_scale, high_precision_flag
+        g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape
     )
     key_transposed_scaled = g.op("Cast", key_transposed_scaled, to_i=onnx_dtype_map["Float"])
     mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
@@ -463,7 +531,8 @@ def export_fp8_mha(

     if not disable_fp8_mha:
         # Softmax's output scale is hard coded to 1.0
-        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag)
+        # Block quantization cannot be applied to the softmax output.
+        attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
         attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"])

     if dropout_p != 0:
@@ -473,7 +542,7 @@ def export_fp8_mha(
             g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
         )
     if not disable_fp8_mha:
-        value = export_fp8(g, value, v_quantized_scale, high_precision_flag)
+        value = export_fp8(g, value, v_quantized_scale, high_precision_flag, v_block_shape)
         value = g.op("Cast", value, to_i=onnx_dtype_map["Float"])
     return g.op(
         "Cast",