14 changes: 14 additions & 0 deletions modelopt/onnx/autocast/__main__.py
@@ -162,6 +162,19 @@ def get_parser() -> argparse.ArgumentParser:
"libraries are in the PATH or LD_LIBRARY_PATH variables."
),
)
parser.add_argument(
"--trt_plugins_precision",
type=str,
default=[],
nargs="+",
help=(
"A space-separated list indicating the precision for each custom op. "
"Each item should have the format <op_type>:<precision> (all inputs and outputs have the same precision) "
"or <op_type>:[<inp1_precision>,<inp2_precision>,...]:[<out1_precision>,<out2_precision>,...] "
"(inputs and outputs can have different precisions), where precision can be fp32 (default) or fp16."
"For example: op_type_1:fp16 op_type_2:[fp16,fp32]:[fp16]."
),
)

return parser

@@ -192,6 +205,7 @@ def main(argv=None):
init_conversion_max_bytes=args.init_conversion_max_bytes,
providers=args.providers,
trt_plugins=args.trt_plugins,
trt_plugins_precision=args.trt_plugins_precision,
max_depth_of_reduction=args.max_depth_of_reduction,
)

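For reference, the flag grammar above is easiest to see executed. Below is a minimal sketch of a parser for the documented format — illustrative only, not the parsing modelopt performs internally (that lives in interpret_trt_plugins_precision_flag):

def parse_precision_spec(spec: str) -> tuple[str, list[str], list[str]]:
    # "<op_type>:<precision>" or "<op_type>:[<inp...>]:[<out...>]"
    parts = spec.split(":")
    op_type = parts[0]
    if len(parts) == 2:
        # One precision applies to every input and output.
        return op_type, [parts[1]], [parts[1]]
    inputs = parts[1].strip("[]").split(",")
    outputs = parts[2].strip("[]").split(",")
    return op_type, inputs, outputs

print(parse_precision_spec("op_type_2:[fp16,fp32]:[fp16]"))
# -> ('op_type_2', ['fp16', 'fp32'], ['fp16'])
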
9 changes: 8 additions & 1 deletion modelopt/onnx/autocast/convert.py
@@ -58,6 +58,7 @@ def convert_to_mixed_precision(
init_conversion_max_bytes: int | None = None,
providers: list[str] = ["cpu"],
trt_plugins: list[str] = [],
trt_plugins_precision: list[str] = [],
max_depth_of_reduction: int | None = None,
) -> onnx.ModelProto:
"""Convert model to mixed precision.
@@ -78,6 +79,7 @@ def convert_to_mixed_precision(
runtime.
providers: List of ORT execution providers.
trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
trt_plugins_precision: List indicating the precision for each custom op.
max_depth_of_reduction: Maximum depth of reduction for node classification.

Returns:
@@ -92,7 +94,11 @@
# Otherwise, prefer to keep the original opset version unless it's very old
min_opset = 22 if low_precision_type == "bf16" else 13
graph_sanitizer = GraphSanitizer(
-    model, min_opset, trt_plugins=trt_plugins, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT
+    model,
+    min_opset,
+    trt_plugins=trt_plugins,
+    trt_plugins_precision=trt_plugins_precision,
+    max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
)
graph_sanitizer.sanitize()
model = graph_sanitizer.model
@@ -118,6 +124,7 @@
init_max=init_max,
custom_rule=custom_rule,
max_depth_of_reduction=max_depth_of_reduction,
custom_ops_low_precision_nodes=graph_sanitizer.custom_ops_low_precision_nodes or [],
)

precision_converter = PrecisionConverter(
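A hedged usage sketch of the new keyword. The argument names below match this diff; the leading model argument is not visible in the hunk and is assumed here (the function body operates on `model`), and the plugin and op names are placeholders:

import onnx
from modelopt.onnx.autocast.convert import convert_to_mixed_precision

model = onnx.load("model_with_custom_ops.onnx")  # hypothetical input model
converted = convert_to_mixed_precision(
    model,
    trt_plugins=["./libmy_plugin.so"],          # hypothetical plugin library
    trt_plugins_precision=["MyCustomOp:fp16"],  # run MyCustomOp in FP16
)
onnx.save(converted, "model_mixed_precision.onnx")
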
19 changes: 19 additions & 0 deletions modelopt/onnx/autocast/graphsanitizer.py
@@ -23,6 +23,8 @@
import modelopt.onnx.autocast.utils as utils
import modelopt.onnx.utils as onnx_utils
from modelopt.onnx.autocast.logging_config import logger
from modelopt.onnx.quantization.graph_utils import cast_custom_ops
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag


class GraphSanitizer:
@@ -34,6 +36,7 @@ def __init__(
min_opset: int = 13,
max_ir_version: int | None = None,
trt_plugins: list[str] | None = [],
trt_plugins_precision: list[str] | None = [],
) -> None:
"""Initialize GraphSanitizer.

@@ -48,7 +51,9 @@ def __init__(
self.max_ir_version = max_ir_version
self.standard_ops = {schema.name for schema in onnx.defs.get_all_schemas()}
self.custom_ops = None
self.custom_ops_low_precision_nodes = []
self.trt_plugins = trt_plugins
self.trt_plugins_precision = trt_plugins_precision or []

def sanitize(self) -> None:
"""Sanitize the model graph.
@@ -67,6 +72,7 @@
self.cleanup_model()
self.set_ir_version(self.max_ir_version)
self.convert_fp64_to_fp32()
self.ensure_custom_ops_precision()

def convert_fp64_to_fp32(self) -> None:
"""Convert FP64 initializers, I/O types, and specific nodes to FP32."""
@@ -88,6 +94,19 @@
logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)

def ensure_custom_ops_precision(self) -> None:
"""Ensure that custom ops run in the requested precision."""
custom_ops_to_cast, _ = interpret_trt_plugins_precision_flag(
self.model,
self.trt_plugins_precision,
)
if custom_ops_to_cast.get("fp16", {}):
self.model = cast_custom_ops(self.model, custom_ops_to_cast["fp16"])
self.custom_ops_low_precision_nodes = [
n.name for n in self.model.graph.node if n.op_type in custom_ops_to_cast["fp16"]
]
logger.info("Ensured custom ops precision")

def find_custom_nodes(self) -> None:
"""Find custom nodes in the model.

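Conceptually, cast_custom_ops wraps the selected custom ops so the plugin sees FP16 tensors while the surrounding graph stays FP32. A rough sketch of that idea — not the actual implementation, which also honors the per-input/per-output precision lists and keeps the graph topologically sorted:

import onnx
from onnx import TensorProto, helper

def wrap_node_io_in_fp16(graph: onnx.GraphProto, node: onnx.NodeProto) -> None:
    """Sketch: cast a node's inputs to FP16 and its outputs back to FP32."""
    for i, name in enumerate(node.input):
        fp16_name = f"{name}_cast_fp16"  # illustrative tensor naming
        graph.node.append(
            helper.make_node("Cast", [name], [fp16_name], to=TensorProto.FLOAT16)
        )
        node.input[i] = fp16_name
    for i, name in enumerate(node.output):
        fp16_name = f"{name}_cast_fp16"
        # The node now emits FP16; a Cast restores FP32 for downstream consumers.
        graph.node.append(
            helper.make_node("Cast", [fp16_name], [name], to=TensorProto.FLOAT)
        )
        node.output[i] = fp16_name
    # A real pass would re-sort graph.node topologically after these insertions.
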
11 changes: 8 additions & 3 deletions modelopt/onnx/autocast/nodeclassifier.py
@@ -360,6 +360,7 @@ def __init__(
data_max: float | None = 1000.0,
init_max: float | None = np.finfo(np.float16).max,
max_depth_of_reduction: int | None = None,
custom_ops_low_precision_nodes: list[str] | None = None,
):
"""Initialize the node classifier.

@@ -375,6 +376,7 @@ def __init__(
data_max: Maximum absolute value allowed for node I/O.
init_max: Maximum absolute value allowed for initializers.
max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
custom_ops_low_precision_nodes: List of custom op node names to convert to low precision.
"""
self.model = model
self.node_to_init_map = node_to_init_map
Expand All @@ -387,6 +389,7 @@ def __init__(
self.data_max = data_max
self.init_max = init_max
self.max_depth_of_reduction = max_depth_of_reduction
self.custom_ops_low_precision_nodes = custom_ops_low_precision_nodes

def _gen_exclude_node_rules(self, reference_data):
"""Generate list of rules for blocking nodes from precision conversion.
@@ -446,12 +449,14 @@ def run(self, ref_outputs_dict=None):
"""
exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict)
include_node_rules = self._gen_include_node_rules()
-    low_precision_nodes = []
+    low_precision_nodes = self.custom_ops_low_precision_nodes or []
high_precision_nodes = []
for node in self.model.graph.node:
# If any condition is met - node will be executed in high precision
-    if any(rule.check(node) for rule in exclude_node_rules) and not any(
-        rule.check(node) for rule in include_node_rules
+    if (
+        node.name not in low_precision_nodes
+        and any(rule.check(node) for rule in exclude_node_rules)
+        and not any(rule.check(node) for rule in include_node_rules)
):
high_precision_nodes.append(node.name)
else:
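The net effect of the run() change: nodes pinned by the plugin-precision flag are seeded into low_precision_nodes and bypass the exclude rules entirely. In miniature (names hypothetical):

pinned = {"my_custom_op_0"}  # stand-in for custom_ops_low_precision_nodes

def classify(name: str, excluded: bool, included: bool) -> str:
    # Mirrors the new condition: a pinned node can never be forced to FP32.
    if name not in pinned and excluded and not included:
        return "high"
    return "low"

assert classify("my_custom_op_0", excluded=True, included=False) == "low"
assert classify("some_matmul", excluded=True, included=False) == "high"
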
2 changes: 1 addition & 1 deletion modelopt/onnx/trt_utils.py
@@ -349,7 +349,7 @@ def load_onnx_model(
def interpret_trt_plugins_precision_flag(
onnx_model: onnx.ModelProto,
trt_plugins_precision: list[str],
-    quantize_mode: str,
+    quantize_mode: str = "int8",
) -> tuple[dict, dict]:
"""Convert custom ops precision flag to dictionaries with custom op and I/O indices to be cast/quantized.
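With the new default, AutoCast can call this helper without passing a quantize mode. A sketch of the call as GraphSanitizer uses it in this PR, given a loaded onnx.ModelProto `model` (the returned dict shapes are inferred from that usage; the op name is a placeholder):

custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
    model, ["MyCustomOp:fp16"]
)
fp16_ops = custom_ops_to_cast.get("fp16", {})  # op type -> I/O indices to cast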