Skip to content

Commit 1ec0af0

Browse files
committed
Added CodeRabbit suggestions
Signed-off-by: gcunhase <[email protected]>
1 parent fe837a8 commit 1ec0af0

File tree

4 files changed

+9
-8
lines changed

modelopt/onnx/quantization/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def get_parser() -> argparse.ArgumentParser:
107107
default=[],
108108
nargs="+",
109109
help=(
110-
"A space-separated list of node types to exclude from FP16 conversion. "
111-
"This is only relevant if '--high_precision_dtype != fp32'."
110+
"A space-separated list of node types to exclude from FP16/BF16 conversion. "
111+
"Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
112112
),
113113
)
114114
argparser.add_argument(

modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def quantize(
168168
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
169169
op_types_to_quantize: list[str] | None = None,
170170
op_types_to_exclude: list[str] | None = None,
171-
op_types_to_exclude_fp16: list[str] = [],
171+
op_types_to_exclude_fp16: list[str] | None = None,
172172
nodes_to_quantize: list[str] | None = None,
173173
nodes_to_exclude: list[str] | None = None,
174174
use_external_data_format: bool = False,
@@ -319,7 +319,7 @@ def quantize(
319319
onnx_model = convert_to_f16(
320320
onnx_model,
321321
keep_io_types=not direct_io_types,
322-
op_block_list=op_types_to_exclude_fp16,
322+
op_block_list=op_types_to_exclude_fp16 or [],
323323
low_precision_type=high_precision_dtype,
324324
trt_plugins=trt_extra_plugin_lib_paths,
325325
)

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def quantize(
119119
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
120120
op_types_to_quantize: list[str] | None = None,
121121
op_types_to_exclude: list[str] | None = None,
122-
op_types_to_exclude_fp16: list[str] = [],
122+
op_types_to_exclude_fp16: list[str] | None = None,
123123
nodes_to_quantize: list[str] | None = None,
124124
nodes_to_exclude: list[str] | None = None,
125125
use_external_data_format: bool = False,
@@ -280,7 +280,7 @@ def quantize(
280280
onnx_model = convert_to_f16(
281281
onnx_model,
282282
keep_io_types=not direct_io_types,
283-
op_block_list=op_types_to_exclude_fp16,
283+
op_block_list=op_types_to_exclude_fp16 or [],
284284
low_precision_type=high_precision_dtype,
285285
trt_plugins=trt_extra_plugin_lib_paths,
286286
)

modelopt/onnx/quantization/quantize.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,9 @@ def quantize(
426426
quantize_mode,
427427
)
428428
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]
429-
op_types_to_exclude_fp16 = op_types_to_exclude_fp16 or []
430-
op_types_to_exclude_fp16.extend(list(custom_ops_to_cast_fp32.keys()))
429+
op_types_to_exclude_fp16 = list(
430+
dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
431+
)
431432

432433
# Use random scales if calibration data is not supplied
433434
if calibration_data is None:

0 commit comments

Comments (0)