Skip to content

Commit 62564db

Browse files
committed
Insert cast nodes for 'FP32 required' custom ops
Signed-off-by: gcunhase <[email protected]>
1 parent 4ff8fc9 commit 62564db

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def quantize(
178178
passes: list[str] = ["concat_elimination"],
179179
log_level: str = "INFO",
180180
calibrate_per_node: bool = False,
181+
custom_ops_to_cast_fp32: list[str] = [],
181182
custom_ops_to_quantize: list[str] = [],
182183
direct_io_types: bool = False,
183184
**kwargs,
@@ -318,7 +319,7 @@ def quantize(
318319
onnx_model = convert_to_f16(
319320
onnx_model,
320321
keep_io_types=not direct_io_types,
321-
op_block_list=["Resize"],
322+
op_block_list=custom_ops_to_cast_fp32,
322323
low_precision_type=high_precision_dtype,
323324
trt_plugins=trt_extra_plugin_lib_paths,
324325
)

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def quantize(
128128
passes: list[str] = ["concat_elimination"],
129129
log_level: str = "INFO",
130130
calibrate_per_node: bool = False,
131+
custom_ops_to_cast_fp32: list[str] = [],
131132
custom_ops_to_quantize: list[str] = [],
132133
direct_io_types: bool = False,
133134
**kwargs,
@@ -279,6 +280,7 @@ def quantize(
279280
onnx_model = convert_to_f16(
280281
onnx_model,
281282
keep_io_types=not direct_io_types,
283+
op_block_list=custom_ops_to_cast_fp32,
282284
low_precision_type=high_precision_dtype,
283285
trt_plugins=trt_extra_plugin_lib_paths,
284286
)

modelopt/onnx/quantization/quantize.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _preprocess_onnx(
8181
override_shapes: str,
8282
simplify: bool = False,
8383
quantize_mode: str = "int8",
84-
) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict]:
84+
) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
8585
logger.info(f"Preprocessing the model {onnx_path}")
8686
intermediate_generated_files = []
8787
output_dir = os.path.dirname(output_path)
@@ -180,13 +180,14 @@ def _preprocess_onnx(
180180
intermediate_generated_files.append(onnx_path)
181181

182182
# If custom op precisions are given, add Cast or Q/DQ where appropriate.
183+
custom_ops_to_cast = {}
183184
custom_ops_to_quantize = {}
184185
if trt_plugins_precision:
185186
custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
186187
onnx_model, trt_plugins_precision, quantize_mode
187188
)
188-
if custom_ops_to_cast:
189-
onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast)
189+
if custom_ops_to_cast.get("fp16", {}):
190+
onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])
190191
onnx_path = os.path.join(output_dir, f"{model_name}_castFP16.onnx")
191192
save_onnx(onnx_model, onnx_path, use_external_data_format)
192193
logger.info(f"Model is cloned to {onnx_path} after casting tensors to FP16")
@@ -199,6 +200,7 @@ def _preprocess_onnx(
199200
has_custom_op,
200201
has_dds_op,
201202
use_external_data_format,
203+
custom_ops_to_cast.get("fp32", {}),
202204
custom_ops_to_quantize,
203205
)
204206

@@ -406,6 +408,7 @@ def quantize(
406408
has_custom_op,
407409
has_dds_op,
408410
use_external_data_format,
411+
custom_ops_to_cast_fp32,
409412
custom_ops_to_quantize,
410413
) = _preprocess_onnx(
411414
onnx_path,
@@ -471,6 +474,7 @@ def quantize(
471474
passes=passes,
472475
log_level=log_level,
473476
calibrate_per_node=calibrate_per_node,
477+
custom_ops_to_cast_fp32=list(custom_ops_to_cast_fp32.keys()),
474478
custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
475479
direct_io_types=direct_io_types,
476480
**kwargs,

modelopt/onnx/trt_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,12 @@ def interpret_trt_plugins_precision_flag(
367367
if trt_plugin_precision.count(":") == 1:
368368
if precision not in supported_precisions:
369369
logger.warning(f"Precision {precision} is not supported. Skipping.")
370-
if precision == "fp16":
371-
custom_ops_to_cast[op_type] = {
372-
"inp": list(range(num_inps)),
373-
"out": list(range(num_outs)),
370+
if precision in ["fp16", "fp32"]:
371+
custom_ops_to_cast[precision] = {
372+
op_type: {
373+
"inp": list(range(num_inps)),
374+
"out": list(range(num_outs)),
375+
}
374376
}
375377
if precision in ["int8", "fp8"]:
376378
if precision != quantize_mode:
@@ -408,10 +410,14 @@ def interpret_trt_plugins_precision_flag(
408410
f"Setting the custom op precision to be the same as quantize mode."
409411
)
410412

411-
# Will cast the inputs to FP16 and the outputs back to FP32
412-
inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == "fp16"]
413-
out_precision_cast = [i for i, p in enumerate(out_precision) if p in ["fp16", "fp32"]]
414-
custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
413+
# Will cast the inputs to FP16/FP32 and the outputs back to FP32
414+
for precision in ["fp16", "fp32"]:
415+
inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
416+
out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
417+
if inp_precision_cast:
418+
custom_ops_to_cast[precision] = {
419+
op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
420+
}
415421

416422
# Will add Q/DQ nodes in the requested I/O indices
417423
inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]]

0 commit comments

Comments
 (0)