Skip to content

Commit 0e0d11a

Browse files
committed
call sanitizer in precision converter
Signed-off-by: Ali Boubezari <[email protected]>
1 parent caf9d39 commit 0e0d11a

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232

3333
import modelopt.onnx.autocast.utils as utils
3434
import modelopt.onnx.utils as onnx_utils
35+
from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
3536
from modelopt.onnx.autocast.logging_config import configure_logging, logger
3637

3738
configure_logging()
@@ -73,6 +74,9 @@ def __init__(
7374
low_precision_type: str = "fp16",
7475
init_conversion_max_bytes: int | None = None,
7576
custom_ops: set[str] | None = None,
77+
min_opset: int = 13,
78+
max_ir_version: int | None = None,
79+
trt_plugins: list[str] | None = [],
7680
) -> None:
7781
"""Initialize PrecisionConverter.
7882
@@ -109,6 +113,9 @@ def __init__(
109113
self.original_network_io.update(
110114
{io.name: io.type.tensor_type.elem_type for io in self.model.graph.output}
111115
)
116+
self.min_opset = min_opset
117+
self.max_ir_version = max_ir_version
118+
self.trt_plugins = trt_plugins
112119

113120
def convert(
114121
self,
@@ -132,6 +139,8 @@ def convert(
132139
"AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
133140
)
134141

142+
self._sanitize_model()
143+
135144
# Filter out nodes that are not allowed to be in low precision
136145
# This is done here and not in NodeClassifier because it is required for the model to be valid
137146
high_precision_nodes, low_precision_nodes = self._filter_unsupported_op_types(
@@ -1030,3 +1039,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
10301039
get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
10311040
return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
10321041
return False
1042+
1043+
def _sanitize_model(self):
1044+
graph_sanitizer = GraphSanitizer(
1045+
self.model,
1046+
self.min_opset,
1047+
trt_plugins=self.trt_plugins,
1048+
max_ir_version=self.max_ir_version,
1049+
)
1050+
graph_sanitizer.sanitize()
1051+
self.model = graph_sanitizer.model

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020

2121
import modelopt.onnx.autocast.utils as utils
2222
import modelopt.onnx.utils as onnx_utils
23-
from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer
2423
from modelopt.onnx.autocast.logging_config import configure_logging
2524
from modelopt.onnx.autocast.precisionconverter import PrecisionConverter
2625

@@ -31,6 +30,9 @@ def low_precision_onnx_type(low_precision_type_str):
3130
return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
3231

3332

33+
LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
34+
35+
3436
####################################################################################################
3537
# Testing with a basic GEMM->Add->Relu graph
3638
####################################################################################################
@@ -1079,17 +1081,16 @@ def test_casted_input_to_output_model(
10791081
):
10801082
model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output
10811083

1082-
min_opset = 22 if low_precision_type == "bf16" else 13
1083-
graph_sanitizer = GraphSanitizer(model, min_opset)
1084-
graph_sanitizer.sanitize()
1085-
model = graph_sanitizer.model
10861084
converter = PrecisionConverter(
10871085
model,
10881086
value_info_map,
10891087
initializer_map,
10901088
node_to_init_map,
10911089
keep_io_types=keep_io_types,
10921090
low_precision_type=low_precision_type,
1091+
min_opset=22 if low_precision_type == "bf16" else 13,
1092+
max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
1093+
trt_plugins=[],
10931094
)
10941095
converted_model = converter.convert(
10951096
high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]

0 commit comments

Comments
 (0)