@@ -81,7 +81,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -180,13 +180,14 @@ def _preprocess_onnx(
         intermediate_generated_files.append(onnx_path)
 
     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
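For context, a minimal sketch of the data shapes this hunk assumes: `interpret_trt_plugins_precision_flag` now returns the cast dict keyed by target precision, which is why the FP16 branch indexes `custom_ops_to_cast["fp16"]`. The inner values are hypothetical; only the top-level keys are implied by the diff.

```python
# Hypothetical shapes only, inferred from this diff rather than the function's docs.
# The cast dict is keyed by target precision; only the "fp16" bucket is cast in
# place here, while the "fp32" bucket is returned to the caller (see next hunk).
custom_ops_to_cast = {
    "fp16": {"CustomGeluPlugin": ...},  # ops whose tensors get FP16 Cast nodes
    "fp32": {"CustomNormPlugin": ...},  # ops to keep in FP32 during conversion
}
custom_ops_to_quantize = {"CustomMatMulPlugin": ...}  # ops that receive Q/DQ nodes
```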
@@ -199,6 +200,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )
 
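The new eighth tuple element carries the FP32 bucket back to the caller. A hedged sketch of how the precision flag plausibly drives this split, assuming the pre-existing `<op_type>:<precision>` entry format of `--trt_plugins_precision`, which is defined outside this diff:

```python
# Assumed input format (hedged): one "<op_type>:<precision>" entry per custom op.
trt_plugins_precision = ["CustomGeluPlugin:fp16", "CustomNormPlugin:fp32"]
# fp16 entries are cast during preprocessing; fp32 entries travel back to
# quantize() via the new tuple element and join op_types_to_exclude_fp16.
```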
@@ -214,6 +216,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -265,6 +268,9 @@ def quantize(
             This flag does not support regular expression.
         op_types_to_exclude:
             List of op types to exclude from quantization. This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16 conversion.
+            This is only relevant if '--high_precision_dtype != fp32'.
         nodes_to_quantize:
             List of node names to quantize. If None (default), all supported nodes are quantized.
             This flag supports regular expression.
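A hedged usage sketch for the new parameter; the import path and the surrounding keyword arguments are assumptions based on this diff and the function names it shows, not a verified full signature:

```python
from modelopt.onnx.quantization import quantize

# Keep LayerNormalization and Softmax in FP32 while the rest of the model is
# converted to FP16 (op types and file path chosen for illustration only).
quantize(
    onnx_path="model.onnx",
    quantize_mode="int8",
    high_precision_dtype="fp16",  # must not be "fp32", or the exclusion list is ignored
    op_types_to_exclude_fp16=["LayerNormalization", "Softmax"],
)
```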
@@ -406,6 +412,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -420,6 +427,16 @@
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
+    # Update the list of op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, this flag has no effect."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
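The merge above relies on `dict.fromkeys` as an order-preserving deduplication (dicts keep insertion order in Python 3.7+). A quick standalone illustration with made-up op names:

```python
# Order-preserving dedup: user-provided exclusions first, then FP32 custom ops.
user_excludes = ["LayerNormalization", "Softmax"]
custom_fp32_ops = ["Softmax", "CustomNormPlugin"]
merged = list(dict.fromkeys(user_excludes + custom_fp32_ops))
assert merged == ["LayerNormalization", "Softmax", "CustomNormPlugin"]
```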
@@ -461,6 +478,7 @@
         calibration_eps=calibration_eps,
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
+        op_types_to_exclude_fp16=op_types_to_exclude_fp16,
         nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,