32
32
33
33
import modelopt .onnx .autocast .utils as utils
34
34
import modelopt .onnx .utils as onnx_utils
35
+ from modelopt .onnx .autocast .graphsanitizer import GraphSanitizer
35
36
from modelopt .onnx .autocast .logging_config import configure_logging , logger
36
37
37
38
configure_logging ()
@@ -73,6 +74,9 @@ def __init__(
73
74
low_precision_type : str = "fp16" ,
74
75
init_conversion_max_bytes : int | None = None ,
75
76
custom_ops : set [str ] | None = None ,
77
+ min_opset : int = 13 ,
78
+ max_ir_version : int | None = None ,
79
+ trt_plugins : list [str ] | None = [],
76
80
) -> None :
77
81
"""Initialize PrecisionConverter.
78
82
@@ -109,6 +113,9 @@ def __init__(
109
113
self .original_network_io .update (
110
114
{io .name : io .type .tensor_type .elem_type for io in self .model .graph .output }
111
115
)
116
+ self .min_opset = min_opset
117
+ self .max_ir_version = max_ir_version
118
+ self .trt_plugins = trt_plugins
112
119
113
120
def convert (
114
121
self ,
@@ -132,6 +139,8 @@ def convert(
132
139
"AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
133
140
)
134
141
142
+ self ._sanitize_model ()
143
+
135
144
# Filter out nodes that are not allowed to be in low precision
136
145
# This is done here and not in NodeClassifier because it is required for the model to be valid
137
146
high_precision_nodes , low_precision_nodes = self ._filter_unsupported_op_types (
@@ -1030,3 +1039,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
1030
1039
get_consumer_nodes = utils .get_consumer_nodes (self .model , const_producer .output [0 ])
1031
1040
return len (get_consumer_nodes ) == 1 and get_consumer_nodes [0 ] == node
1032
1041
return False
1042
+
1043
def _sanitize_model(self) -> None:
    """Run ``GraphSanitizer`` over the current model and keep its result.

    A ``GraphSanitizer`` is built from ``self.model``, ``self.min_opset``,
    ``self.trt_plugins`` and ``self.max_ir_version``; its ``sanitize()``
    method is invoked, and ``self.model`` is then rebound to the
    sanitizer's ``model`` attribute.
    """
    sanitizer = GraphSanitizer(
        self.model,
        self.min_opset,
        max_ir_version=self.max_ir_version,
        trt_plugins=self.trt_plugins,
    )
    sanitizer.sanitize()
    # Take the model back from the sanitizer instead of relying on the
    # original object having been updated in place.
    self.model = sanitizer.model
0 commit comments