@@ -77,7 +77,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -176,13 +176,14 @@ def _preprocess_onnx(
         intermediate_generated_files.append(onnx_path)

     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -195,6 +196,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )

@@ -210,6 +212,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -261,6 +264,9 @@ def quantize(
             This flag does not support regular expression.
         op_types_to_exclude:
             List of op types to exclude from quantization. This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16 conversion.
+            This is only relevant if '--high_precision_dtype != fp32'.
         nodes_to_quantize:
             List of node names to quantize. If None (default), all supported nodes are quantized.
             This flag supports regular expression.
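The new argument rides alongside the existing exclusion flag: `op_types_to_exclude` removes op types from quantization, while `op_types_to_exclude_fp16` keeps them in FP32 during the FP16/BF16 conversion pass. A hypothetical invocation for illustration (the import path and argument values are assumptions; only the parameter names visible in this diff are guaranteed):

```python
# Hypothetical usage sketch; import path is an assumption.
from modelopt.onnx.quantization import quantize

quantize(
    onnx_path="model.onnx",
    quantize_mode="int8",
    high_precision_dtype="fp16",  # FP16 conversion is active, so the new flag applies
    op_types_to_exclude=["Resize"],  # excluded from INT8 quantization
    op_types_to_exclude_fp16=["Softmax"],  # additionally kept in FP32 during conversion
)
```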
@@ -402,6 +408,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -416,6 +423,16 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]

+    # Update list with op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, this flag is void."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
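The merge in this hunk uses `dict.fromkeys` for order-preserving deduplication (dicts keep insertion order since Python 3.7), so user-supplied op types come first and the op types pulled from `custom_ops_to_cast_fp32` are appended without duplicates. A standalone illustration (op type names are made up):

```python
# Order-preserving dedup via dict.fromkeys, as used in the hunk above.
user_excludes = ["LayerNormalization", "Softmax"]
custom_fp32_ops = {"Softmax": ..., "CustomPluginB": ...}  # values irrelevant; keys are op types

merged = list(dict.fromkeys(user_excludes + list(custom_fp32_ops.keys())))
print(merged)  # ['LayerNormalization', 'Softmax', 'CustomPluginB']
```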
@@ -457,6 +474,7 @@ def quantize(
         calibration_eps=calibration_eps,
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
+        op_types_to_exclude_fp16=op_types_to_exclude_fp16,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,