diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ed724e55a..ecf8fb1d1 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,15 @@
 Model Optimizer Changelog (Linux)
 =================================
 
+0.39 (2025-10-xx)
+^^^^^^^^^^^^^^^^^
+
+**Deprecations**
+
+**New Features**
+
+- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
+
 0.37 (2025-09-xx)
 ^^^^^^^^^^^^^^^^^
 
diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py
index 1be2e75df..55cca6ee5 100644
--- a/modelopt/onnx/quantization/__main__.py
+++ b/modelopt/onnx/quantization/__main__.py
@@ -90,28 +90,33 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--op_types_to_quantize",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node types to quantize.",
     )
     argparser.add_argument(
         "--op_types_to_exclude",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node types to exclude from quantization.",
     )
+    argparser.add_argument(
+        "--op_types_to_exclude_fp16",
+        type=str,
+        nargs="+",
+        help=(
+            "A space-separated list of node types to exclude from FP16/BF16 conversion. "
+            "Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
+        ),
+    )
     argparser.add_argument(
         "--nodes_to_quantize",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node names to quantize. Regular expressions are supported.",
     )
     argparser.add_argument(
         "--nodes_to_exclude",
         type=str,
-        default=[],
         nargs="+",
         help="A space-separated list of node names to exclude from quantization. Regular expressions are supported.",
     )
@@ -274,6 +279,7 @@ def main():
         override_shapes=args.override_shapes,
         op_types_to_quantize=args.op_types_to_quantize,
         op_types_to_exclude=args.op_types_to_exclude,
+        op_types_to_exclude_fp16=args.op_types_to_exclude_fp16,
         nodes_to_quantize=args.nodes_to_quantize,
         nodes_to_exclude=args.nodes_to_exclude,
         use_external_data_format=args.use_external_data_format,
diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py
index 4ee9c4f67..1ef3c9799 100755
--- a/modelopt/onnx/quantization/fp8.py
+++ b/modelopt/onnx/quantization/fp8.py
@@ -168,6 +168,7 @@ def quantize(
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -318,7 +319,7 @@ def quantize(
         onnx_model = convert_to_f16(
             onnx_model,
             keep_io_types=not direct_io_types,
-            op_block_list=["Resize"],
+            op_block_list=op_types_to_exclude_fp16 or [],
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )
diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py
index 156e798b0..5a878fb76 100755
--- a/modelopt/onnx/quantization/int8.py
+++ b/modelopt/onnx/quantization/int8.py
@@ -119,6 +119,7 @@ def quantize(
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -279,6 +280,7 @@ def quantize(
         onnx_model = convert_to_f16(
             onnx_model,
             keep_io_types=not direct_io_types,
+            op_block_list=op_types_to_exclude_fp16 or [],
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )
diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py
index 9f6f9cae9..daf785326 100755
--- a/modelopt/onnx/quantization/quantize.py
+++ b/modelopt/onnx/quantization/quantize.py
@@ -81,7 +81,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
+) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
     output_dir = os.path.dirname(output_path)
@@ -180,13 +180,14 @@ def _preprocess_onnx(
         intermediate_generated_files.append(onnx_path)
 
     # If custom op precisions are given, add Cast or Q/DQ where appropriate.
+    custom_ops_to_cast = {}
     custom_ops_to_quantize = {}
     if trt_plugins_precision:
         custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
             onnx_model, trt_plugins_precision, quantize_mode
         )
-        if custom_ops_to_cast:
-            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
+        if custom_ops_to_cast.get("fp16", {}):
+            onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
             onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
             save_onnx(onnx_model, onnx_path, use_external_data_format)
             logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -199,6 +200,7 @@ def _preprocess_onnx(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast.get("fp32", {}),
         custom_ops_to_quantize,
     )
 
@@ -214,6 +216,7 @@ def quantize(
     override_shapes: str | None = None,
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
+    op_types_to_exclude_fp16: list[str] | None = None,
    nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -265,6 +268,9 @@ def quantize(
             This flag does not support regular expression.
         op_types_to_exclude: List of op types to exclude from quantization.
             This flag does not support regular expression.
+        op_types_to_exclude_fp16:
+            List of op types to exclude from FP16/BF16 conversion.
+            This is only relevant when 'high_precision_dtype' is not 'fp32'.
         nodes_to_quantize:
             List of node names to quantize. If None (default), all supported nodes are quantized.
             This flag supports regular expression.
@@ -406,6 +412,7 @@ def quantize(
         has_custom_op,
         has_dds_op,
         use_external_data_format,
+        custom_ops_to_cast_fp32,
         custom_ops_to_quantize,
     ) = _preprocess_onnx(
         onnx_path,
@@ -420,6 +427,16 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
+    # Update the list of op types to exclude from FP16/BF16 conversion
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
+    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
+        logger.warning(
+            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
+            "Since the model won't be converted to a lower precision, these exclusions have no effect."
+        )
+
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
@@ -461,6 +478,7 @@ def quantize(
             calibration_eps=calibration_eps,
             op_types_to_quantize=op_types_to_quantize,
             op_types_to_exclude=op_types_to_exclude,
+            op_types_to_exclude_fp16=op_types_to_exclude_fp16,
             nodes_to_quantize=nodes_to_quantize,
             nodes_to_exclude=nodes_to_exclude,
             use_external_data_format=use_external_data_format,
diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py
index 85312ecd4..2231ccd00 100644
--- a/modelopt/onnx/trt_utils.py
+++ b/modelopt/onnx/trt_utils.py
@@ -367,10 +367,10 @@ def interpret_trt_plugins_precision_flag(
         if trt_plugin_precision.count(":") == 1:
             if precision not in supported_precisions:
                 logger.warning(f"Precision {precision} is not supported. Skipping.")
-            if precision == "fp16":
-                custom_ops_to_cast[op_type] = {
+            if precision in ["fp16", "fp32"]:
+                custom_ops_to_cast.setdefault(precision, {})[op_type] = {
                     "inp": list(range(num_inps)),
                     "out": list(range(num_outs)),
                 }
             if precision in ["int8", "fp8"]:
                 if precision != quantize_mode:
@@ -408,10 +408,14 @@ def interpret_trt_plugins_precision_flag(
                     f"Setting the custom op precision to be the same as quantize mode."
                 )
 
-        # Will cast the inputs to FP16 and the outputs back to FP32
-        inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == "fp16"]
-        out_precision_cast = [i for i, p in enumerate(out_precision) if p in ["fp16", "fp32"]]
-        custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
+        # Will cast the requested inputs/outputs to FP16 or FP32
+        for precision in ["fp16", "fp32"]:
+            inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+            out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+            if inp_precision_cast or out_precision_cast:
+                custom_ops_to_cast.setdefault(precision, {})[op_type] = {
+                    "inp": inp_precision_cast, "out": out_precision_cast
+                }
 
         # Will add Q/DQ nodes in the requested I/O indices
         inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]]
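Usage illustration (a minimal sketch, not taken from this diff): how the new exclusion option is intended to be driven from the Python API. Only ``op_types_to_exclude_fp16``, ``high_precision_dtype``, and ``trt_plugins_precision`` come from this change; the import path and the ``onnx_path``/``quantize_mode`` argument names are assumed from the existing ``modelopt.onnx.quantization`` package and may differ, and the plugin name is hypothetical.

```python
# Sketch only: keep Resize nodes in FP32 while the rest of the model is converted to FP16.
from modelopt.onnx.quantization.quantize import quantize  # assumed import path

quantize(
    onnx_path="model.onnx",               # assumed argument name
    quantize_mode="int8",                 # assumed argument name
    high_precision_dtype="fp16",          # convert remaining float tensors to FP16
    op_types_to_exclude_fp16=["Resize"],  # new option added in this change
    # For a custom TensorRT op, requesting FP32 precision has the same effect
    # (hypothetical plugin name):
    # trt_plugins_precision=["MyCustomPlugin:fp32"],
)
```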