@@ -81,7 +81,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -180,13 +180,14 @@ def _preprocess_onnx(
         intermediate_generated_files.append(onnx_path)

     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -199,6 +200,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )

@@ -214,6 +216,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -265,6 +268,9 @@ def quantize(
             This flag does not support regular expression.
         op_types_to_exclude:
             List of op types to exclude from quantization. This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16 conversion.
+            This is only relevant if '--high_precision_dtype != fp32'.
         nodes_to_quantize:
             List of node names to quantize. If None (default), all supported nodes are quantized.
             This flag supports regular expression.
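A hedged usage sketch of the new parameter, assuming `quantize` is importable as shown below; the model path and op types are placeholders.

```python
from modelopt.onnx.quantization import quantize  # assumed import path

# Keep LayerNormalization and Softmax in FP32 while the rest of the model
# is converted to FP16; "model.onnx" and the op types are placeholders.
quantize(
    "model.onnx",
    quantize_mode="int8",
    high_precision_dtype="fp16",  # the new flag is a no-op when this stays "fp32"
    op_types_to_exclude_fp16=["LayerNormalization", "Softmax"],
)
```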
@@ -406,6 +412,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -420,6 +427,16 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]

+    # Update list with op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, this flag is void."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
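The merge above uses `dict.fromkeys` as an order-preserving dedup: user-supplied op types come first, and duplicates contributed by the FP32-cast custom ops are dropped. A standalone sketch with made-up values:

```python
# dict keys preserve first-insertion order (Python 3.7+), so the merged list
# keeps the user's ordering and skips repeated entries.
user_excludes = ["Softmax", "LayerNormalization"]          # hypothetical values
custom_ops_to_cast_fp32 = {"MyPlugin": {}, "Softmax": {}}  # hypothetical values

merged = list(dict.fromkeys(user_excludes + list(custom_ops_to_cast_fp32.keys())))
print(merged)  # ['Softmax', 'LayerNormalization', 'MyPlugin']
```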
@@ -461,6 +478,7 @@ def quantize(
         calibration_eps=calibration_eps,
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
+        op_types_to_exclude_fp16=op_types_to_exclude_fp16,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,