@@ -77,7 +77,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -176,13 +176,14 @@ def _preprocess_onnx(
     intermediate_generated_files.append(onnx_path)
 
     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
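
For orientation, a minimal sketch of the two precision dicts this hunk now expects from `interpret_trt_plugins_precision_flag`: a cast dict split by `"fp16"`/`"fp32"` keys, plus a separate quantize dict. The op names and the per-op `inp`/`out` index layout below are illustrative assumptions, not the exact format.

```python
# Assumed shapes, for illustration only: op types mapped to the tensor
# indices to cast (fp16/fp32) or to wrap in Q/DQ nodes (quantize).
custom_ops_to_cast = {
    "fp16": {"CustomGeluPlugin": {"inp": [0], "out": [0]}},    # cast I/O to FP16
    "fp32": {"CustomNMSPlugin": {"inp": [0, 1], "out": [0]}},  # keep I/O in FP32
}
custom_ops_to_quantize = {"CustomConvPlugin": {"inp": [0], "out": []}}
```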
@@ -195,6 +196,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )
 
@@ -210,6 +212,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -261,6 +264,9 @@ def quantize(
             This flag does not support regular expression.
         op_types_to_exclude:
             List of op types to exclude from quantization. This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16/BF16 conversion.
+            Only relevant when 'high_precision_dtype' is not 'fp32'.
         nodes_to_quantize:
             List of node names to quantize. If None (default), all supported nodes are quantized.
             This flag supports regular expression.
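
A hedged usage sketch of the new parameter (the import path and model file are assumptions for illustration): the listed op types stay in FP32 while the rest of the model is lowered to FP16.

```python
# Hypothetical call: INT8 quantization with FP16 lowering, while keeping
# Resize and Softmax nodes in FP32 (import path is an assumption).
from modelopt.onnx.quantization import quantize

quantize(
    onnx_path="model.onnx",
    quantize_mode="int8",
    high_precision_dtype="fp16",
    op_types_to_exclude_fp16=["Resize", "Softmax"],
)
```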
@@ -402,6 +408,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -416,6 +423,16 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
+    # Update list with op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, this flag has no effect."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
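
The merge above uses `dict.fromkeys` rather than `set` so duplicates are dropped while first-seen order is preserved, keeping user-specified op types ahead of those derived from the plugin precision flag. A standalone illustration:

```python
# dict.fromkeys deduplicates while keeping first-seen order; set() would not.
user_excluded = ["Resize", "Softmax", "Resize"]  # e.g. op_types_to_exclude_fp16
derived_fp32 = ["CustomNMSPlugin", "Softmax"]    # e.g. custom_ops_to_cast_fp32.keys()
merged = list(dict.fromkeys(user_excluded + derived_fp32))
print(merged)  # ['Resize', 'Softmax', 'CustomNMSPlugin']
```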
@@ -457,6 +474,7 @@ def quantize(
         calibration_eps=calibration_eps,
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
+        op_types_to_exclude_fp16=op_types_to_exclude_fp16,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,