Skip to content

Commit 69cb674

Browse files
committed
Add support for FP16-only custom ops
Signed-off-by: gcunhase <[email protected]>
1 parent 9077a03 commit 69cb674

File tree

5 files changed

+48
-4
lines changed

5 files changed

+48
-4
lines changed

modelopt/onnx/autocast/__main__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,20 @@ def get_parser() -> argparse.ArgumentParser:
162162
"libraries are in the PATH or LD_LIBRARY_PATH variables."
163163
),
164164
)
165+
parser.add_argument(
166+
"--trt_plugins_precision",
167+
type=str,
168+
default=[],
169+
nargs="+",
170+
help=(
171+
"A space-separated list indicating the precision for each custom op. "
172+
"Each item should have the format <op_type>:<precision> (all inputs and outputs have the same precision) "
173+
"or <op_type>:[<inp1_precision>,<inp2_precision>,...]:[<out1_precision>,<out2_precision>,...] "
174+
"(inputs and outputs can have different precisions), where precision can be fp32 (default), "
175+
"fp16, int8, or fp8. Note that int8/fp8 should be the same as the quantization mode. "
176+
"For example: op_type_1:fp16 op_type_2:[int8,fp32]:[int8]."
177+
),
178+
)
165179

166180
return parser
167181

@@ -192,6 +206,7 @@ def main(argv=None):
192206
init_conversion_max_bytes=args.init_conversion_max_bytes,
193207
providers=args.providers,
194208
trt_plugins=args.trt_plugins,
209+
trt_plugins_precision=args.trt_plugins_precision,
195210
max_depth_of_reduction=args.max_depth_of_reduction,
196211
)
197212

modelopt/onnx/autocast/convert.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def convert_to_mixed_precision(
5858
init_conversion_max_bytes: int | None = None,
5959
providers: list[str] = ["cpu"],
6060
trt_plugins: list[str] = [],
61+
trt_plugins_precision: list[str] = [],
6162
max_depth_of_reduction: int | None = None,
6263
) -> onnx.ModelProto:
6364
"""Convert model to mixed precision.
@@ -78,6 +79,7 @@ def convert_to_mixed_precision(
7879
runtime.
7980
providers: List of ORT execution providers.
8081
trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
82+
trt_plugins_precision: List indicating the precision for each custom op.
8183
max_depth_of_reduction: Maximum depth of reduction for node classification.
8284
8385
Returns:
@@ -92,7 +94,11 @@ def convert_to_mixed_precision(
9294
# Otherwise, prefer to keep the original opset version unless it's very old
9395
min_opset = 22 if low_precision_type == "bf16" else 13
9496
graph_sanitizer = GraphSanitizer(
95-
model, min_opset, trt_plugins=trt_plugins, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT
97+
model,
98+
min_opset,
99+
trt_plugins=trt_plugins,
100+
trt_plugins_precision=trt_plugins_precision,
101+
max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
96102
)
97103
graph_sanitizer.sanitize()
98104
model = graph_sanitizer.model
@@ -118,6 +124,7 @@ def convert_to_mixed_precision(
118124
init_max=init_max,
119125
custom_rule=custom_rule,
120126
max_depth_of_reduction=max_depth_of_reduction,
127+
custom_ops_low_precision_nodes=graph_sanitizer.custom_ops_low_precision_nodes or [],
121128
)
122129

123130
precision_converter = PrecisionConverter(

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import modelopt.onnx.autocast.utils as utils
2424
import modelopt.onnx.utils as onnx_utils
2525
from modelopt.onnx.autocast.logging_config import logger
26+
from modelopt.onnx.quantization.graph_utils import cast_custom_ops
27+
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag
2628

2729

2830
class GraphSanitizer:
@@ -34,6 +36,7 @@ def __init__(
3436
min_opset: int = 13,
3537
max_ir_version: int | None = None,
3638
trt_plugins: list[str] | None = [],
39+
trt_plugins_precision: list[str] | None = [],
3740
) -> None:
3841
"""Initialize GraphSanitizer.
3942
@@ -48,7 +51,9 @@ def __init__(
4851
self.max_ir_version = max_ir_version
4952
self.standard_ops = {schema.name for schema in onnx.defs.get_all_schemas()}
5053
self.custom_ops = None
54+
self.custom_ops_low_precision_nodes = []
5155
self.trt_plugins = trt_plugins
56+
self.trt_plugins_precision = trt_plugins_precision or []
5257

5358
def sanitize(self) -> None:
5459
"""Sanitize the model graph.
@@ -67,6 +72,7 @@ def sanitize(self) -> None:
6772
self.cleanup_model()
6873
self.set_ir_version(self.max_ir_version)
6974
self.convert_fp64_to_fp32()
75+
self.ensure_custom_ops_precision()
7076

7177
def convert_fp64_to_fp32(self) -> None:
7278
"""Convert FP64 initializers, I/O types, and specific nodes to FP32."""
@@ -88,6 +94,19 @@ def convert_fp64_to_fp32(self) -> None:
8894
logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
8995
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
9096

97+
def ensure_custom_ops_precision(self) -> None:
98+
"""Ensure that custom ops run in the requested precision."""
99+
custom_ops_to_cast, _ = interpret_trt_plugins_precision_flag(
100+
self.model,
101+
self.trt_plugins_precision,
102+
)
103+
if custom_ops_to_cast.get("fp16", {}):
104+
self.model = cast_custom_ops(self.model, custom_ops_to_cast["fp16"])
105+
self.custom_ops_low_precision_nodes = [
106+
n.name for n in self.model.graph.node if n.op_type in custom_ops_to_cast["fp16"]
107+
]
108+
logger.info("Ensured custom ops precision")
109+
91110
def find_custom_nodes(self) -> None:
92111
"""Find custom nodes in the model.
93112

modelopt/onnx/autocast/nodeclassifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def __init__(
360360
data_max: float | None = 1000.0,
361361
init_max: float | None = np.finfo(np.float16).max,
362362
max_depth_of_reduction: int | None = None,
363+
custom_ops_low_precision_nodes: list[str] | None = None,
363364
):
364365
"""Initialize the node classifier.
365366
@@ -375,6 +376,7 @@ def __init__(
375376
data_max: Maximum absolute value allowed for node I/O.
376377
init_max: Maximum absolute value allowed for initializers.
377378
max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
379+
custom_ops_low_precision_nodes: List of custom op node names to convert to low precision.
378380
"""
379381
self.model = model
380382
self.node_to_init_map = node_to_init_map
@@ -387,6 +389,7 @@ def __init__(
387389
self.data_max = data_max
388390
self.init_max = init_max
389391
self.max_depth_of_reduction = max_depth_of_reduction
392+
self.custom_ops_low_precision_nodes = custom_ops_low_precision_nodes
390393

391394
def _gen_exclude_node_rules(self, reference_data):
392395
"""Generate list of rules for blocking nodes from precision conversion.
@@ -446,11 +449,11 @@ def run(self, ref_outputs_dict=None):
446449
"""
447450
exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict)
448451
include_node_rules = self._gen_include_node_rules()
449-
low_precision_nodes = []
452+
low_precision_nodes = self.custom_ops_low_precision_nodes or []
450453
high_precision_nodes = []
451454
for node in self.model.graph.node:
452455
# If any condition is met - node will be executed in high precision
453-
if any(rule.check(node) for rule in exclude_node_rules) and not any(
456+
if node.name not in low_precision_nodes and any(rule.check(node) for rule in exclude_node_rules) and not any(
454457
rule.check(node) for rule in include_node_rules
455458
):
456459
high_precision_nodes.append(node.name)

modelopt/onnx/trt_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def load_onnx_model(
349349
def interpret_trt_plugins_precision_flag(
350350
onnx_model: onnx.ModelProto,
351351
trt_plugins_precision: list[str],
352-
quantize_mode: str,
352+
quantize_mode: str = "int8",
353353
) -> tuple[dict, dict]:
354354
"""Convert custom ops precision flag to dictionaries with custom op and I/O indices to be cast/quantized.
355355

0 commit comments

Comments (0)