Skip to content

Commit 7653d8a

Browse files
committed
Add support for FP16-only custom ops
Signed-off-by: gcunhase <[email protected]>
1 parent d0e83ed commit 7653d8a

File tree

5 files changed, with 50 additions and 4 deletions

modelopt/onnx/autocast/__main__.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,20 @@ def get_parser() -> argparse.ArgumentParser:
143143
"libraries are in the PATH or LD_LIBRARY_PATH variables."
144144
),
145145
)
146+
parser.add_argument(
147+
"--trt_plugins_precision",
148+
type=str,
149+
default=[],
150+
nargs="+",
151+
help=(
152+
"A space-separated list indicating the precision for each custom op. "
153+
"Each item should have the format <op_type>:<precision> (all inputs and outputs have the same precision) "
154+
"or <op_type>:[<inp1_precision>,<inp2_precision>,...]:[<out1_precision>,<out2_precision>,...] "
155+
"(inputs and outputs can have different precisions), where precision can be fp32 (default), "
156+
"fp16, int8, or fp8. Note that int8/fp8 should be the same as the quantization mode. "
157+
"For example: op_type_1:fp16 op_type_2:[int8,fp32]:[int8]."
158+
),
159+
)
146160

147161
return parser
148162

@@ -171,6 +185,7 @@ def main(argv=None):
171185
init_conversion_max_bytes=args.init_conversion_max_bytes,
172186
providers=args.providers,
173187
trt_plugins=args.trt_plugins,
188+
trt_plugins_precision=args.trt_plugins_precision,
174189
max_depth_of_reduction=args.max_depth_of_reduction,
175190
)
176191

modelopt/onnx/autocast/convert.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ def convert_to_mixed_precision(
5656
init_conversion_max_bytes: int | None = None,
5757
providers: list[str] = ["cpu"],
5858
trt_plugins: list[str] = [],
59+
trt_plugins_precision: list[str] = [],
5960
max_depth_of_reduction: int | None = None,
6061
) -> onnx.ModelProto:
6162
"""Convert model to mixed precision.
@@ -74,6 +75,7 @@ def convert_to_mixed_precision(
7475
runtime.
7576
providers: List of ORT execution providers.
7677
trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
78+
trt_plugins_precision: List indicating the precision for each custom op.
7779
max_depth_of_reduction: Maximum depth of reduction for node classification.
7880
7981
Returns:
@@ -88,7 +90,11 @@ def convert_to_mixed_precision(
8890
# Otherwise, prefer to keep the original opset version unless it's very old
8991
min_opset = 22 if low_precision_type == "bf16" else 13
9092
graph_sanitizer = GraphSanitizer(
91-
model, min_opset, trt_plugins=trt_plugins, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT
93+
model,
94+
min_opset,
95+
trt_plugins=trt_plugins,
96+
trt_plugins_precision=trt_plugins_precision,
97+
max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
9298
)
9399
graph_sanitizer.sanitize()
94100
model = graph_sanitizer.model
@@ -112,6 +118,7 @@ def convert_to_mixed_precision(
112118
init_max=init_max,
113119
custom_rule=custom_rule,
114120
max_depth_of_reduction=max_depth_of_reduction,
121+
custom_ops_low_precision_nodes=graph_sanitizer.custom_ops_low_precision_nodes or [],
115122
)
116123

117124
precision_converter = PrecisionConverter(

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@
2323
import modelopt.onnx.autocast.utils as utils
2424
import modelopt.onnx.utils as onnx_utils
2525
from modelopt.onnx.autocast.logging_config import logger
26+
from modelopt.onnx.quantization.graph_utils import cast_custom_ops
27+
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag
2628

2729

2830
class GraphSanitizer:
@@ -34,6 +36,7 @@ def __init__(
3436
min_opset: int = 13,
3537
max_ir_version: int | None = None,
3638
trt_plugins: list[str] | None = [],
39+
trt_plugins_precision: list[str] | None = [],
3740
) -> None:
3841
"""Initialize GraphSanitizer.
3942
@@ -48,7 +51,9 @@ def __init__(
4851
self.max_ir_version = max_ir_version
4952
self.standard_ops = {schema.name for schema in onnx.defs.get_all_schemas()}
5053
self.custom_ops = None
54+
self.custom_ops_low_precision_nodes = []
5155
self.trt_plugins = trt_plugins
56+
self.trt_plugins_precision = trt_plugins_precision or []
5257

5358
def sanitize(self) -> None:
5459
"""Sanitize the model graph.
@@ -67,6 +72,7 @@ def sanitize(self) -> None:
6772
self.cleanup_model()
6873
self.set_ir_version(self.max_ir_version)
6974
self.convert_fp64_to_fp32()
75+
self.ensure_custom_ops_precision()
7076

7177
def convert_fp64_to_fp32(self) -> None:
7278
"""Convert FP64 initializers, I/O types, and specific nodes to FP32."""
@@ -88,6 +94,19 @@ def convert_fp64_to_fp32(self) -> None:
8894
logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
8995
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
9096

97+
def ensure_custom_ops_precision(self) -> None:
98+
"""Ensure that custom ops run in the requested precision."""
99+
custom_ops_to_cast, _ = interpret_trt_plugins_precision_flag(
100+
self.model,
101+
self.trt_plugins_precision,
102+
)
103+
if custom_ops_to_cast.get("fp16", {}):
104+
self.model = cast_custom_ops(self.model, custom_ops_to_cast["fp16"])
105+
self.custom_ops_low_precision_nodes = [
106+
n.name for n in self.model.graph.node if n.op_type in custom_ops_to_cast["fp16"]
107+
]
108+
logger.info("Ensured custom ops precision")
109+
91110
def find_custom_nodes(self) -> None:
92111
"""Find custom nodes in the model.
93112

modelopt/onnx/autocast/nodeclassifier.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -336,6 +336,7 @@ def __init__(
336336
data_max: float | None = 1000.0,
337337
init_max: float | None = np.finfo(np.float16).max,
338338
max_depth_of_reduction: int | None = None,
339+
custom_ops_low_precision_nodes: list[str] | None = None,
339340
):
340341
"""Initialize the node classifier.
341342
@@ -349,6 +350,7 @@ def __init__(
349350
data_max: Maximum absolute value allowed for node I/O.
350351
init_max: Maximum absolute value allowed for initializers.
351352
max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
353+
custom_ops_low_precision_nodes: List of custom op node names to convert to low precision.
352354
"""
353355
self.model = model
354356
self.node_to_init_map = node_to_init_map
@@ -359,6 +361,7 @@ def __init__(
359361
self.data_max = data_max
360362
self.init_max = init_max
361363
self.max_depth_of_reduction = max_depth_of_reduction
364+
self.custom_ops_low_precision_nodes = custom_ops_low_precision_nodes
362365

363366
def _gen_block_node_rules(self, reference_data):
364367
"""Generate list of rules for blocking nodes from precision conversion.
@@ -403,11 +406,13 @@ def run(self, ref_outputs_dict=None):
403406
tuple: Lists of node names (low_precision_nodes, high_precision_nodes).
404407
"""
405408
block_node_rules = self._gen_block_node_rules(ref_outputs_dict)
406-
low_precision_nodes = []
409+
low_precision_nodes = self.custom_ops_low_precision_nodes or []
407410
high_precision_nodes = []
408411
for node in self.model.graph.node:
409412
# If any condition is met - node will be executed in high precision
410-
if any(rule.check(node) for rule in block_node_rules):
413+
if node.name not in low_precision_nodes and any(
414+
rule.check(node) for rule in block_node_rules
415+
):
411416
high_precision_nodes.append(node.name)
412417
else:
413418
low_precision_nodes.append(node.name)

modelopt/onnx/trt_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,7 @@ def load_onnx_model(
349349
def interpret_trt_plugins_precision_flag(
350350
onnx_model: onnx.ModelProto,
351351
trt_plugins_precision: list[str],
352-
quantize_mode: str,
352+
quantize_mode: str = "int8",
353353
) -> tuple[dict, dict]:
354354
"""Convert custom ops precision flag to dictionaries with custom op and I/O indices to be cast/quantized.
355355

0 commit comments

Comments
 (0)