-
Notifications
You must be signed in to change notification settings - Fork 105
Expand file tree
/
Copy path__init__.py
More file actions
179 lines (148 loc) · 6.87 KB
/
__init__.py
File metadata and controls
179 lines (148 loc) · 6.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "ConvertVersionPass",
    "convert_version",
]
import logging
import onnx
import onnx_ir.passes.common as common_passes
from onnxscript import ir
from onnxscript.version_converter import _c_api_utils, _version_converter
# Module-level logger; used to report converter fallbacks and C API failures.
logger = logging.getLogger(__name__)
class ConvertVersionPass(ir.passes.InPlacePass):
    """Convert a model to a given ONNX opset version, in place.

    The pass runs a fixed pipeline:

    1. inline all functions (the conversion step rejects models with functions),
    2. convert the opset version,
    3. drop nodes, functions and opset imports that are no longer used.

    The conversion step uses the onnxscript version converter. Only when
    ``fallback`` is True and that converter does not support the target
    version is the ONNX C API version converter used instead. If the C API
    conversion fails, the conversion step is a no-op and a warning is logged.
    The surrounding inline and cleanup passes still run.

    Attributes:
        target_version: The target ONNX opset version to convert the model to.
        fallback: Whether to fallback to the onnx version converter if the
            target version is not supported. Default is False.
    """

    def __init__(self, target_version: int, fallback: bool = False) -> None:
        super().__init__()
        self.target_version = target_version
        self.fallback = fallback
        # Stages are constructed and executed in this exact order.
        stages = (
            common_passes.InlinePass(),
            _ConvertVersionPassRequiresInline(
                target_version=target_version,
                fallback=fallback,
            ),
            common_passes.RemoveUnusedNodesPass(),
            common_passes.RemoveUnusedFunctionsPass(),
            common_passes.RemoveUnusedOpsetsPass(),
        )
        self.convert_pass = ir.passes.Sequential(*stages)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # Delegate to the composed pipeline built in __init__.
        pipeline = self.convert_pass
        return pipeline(model)
class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
    """Convert a function-free model to the specified ONNX opset version.

    The model must not contain functions; run ``common_passes.InlinePass``
    first. The onnxscript version converter is used by default. When
    ``fallback`` is True and the onnxscript converter does not support the
    target version, the ONNX C API version converter is used instead. This
    pass is in-place.

    If the ONNX C API conversion raises, a warning is logged and the model is
    returned unmodified (the pass result reports ``modified=False``).

    Attributes:
        target_version: The target ONNX opset version to convert the model to.
        fallback: Whether to fallback to the onnx version converter if the
            target version is not supported by the onnxscript converter.
    """

    def __init__(self, target_version: int, fallback: bool) -> None:
        super().__init__()
        self.target_version = target_version
        self.fallback = fallback

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Convert ``model`` in place.

        Raises:
            ValueError: If the model still contains functions.
        """
        if model.functions:
            raise ValueError(
                "The model contains functions. The version conversion pass does not support "
                "functions. Please use `common_passes.InlinePass` to inline the "
                f"functions before applying this pass ({self.__class__.__name__})."
            )
        if "" in model.graph.opset_imports:
            onnx_opset_version = model.graph.opset_imports[""]
            if onnx_opset_version == self.target_version:
                # No need to convert the version
                return ir.passes.PassResult(model, False)

        # When fallback is disabled, always use the onnxscript version converter.
        # When fallback is enabled, use the onnxscript version converter only if
        # it supports the target version; otherwise use the onnx C API.
        if not self.fallback or _version_converter.version_supported(
            model, self.target_version
        ):
            _version_converter.convert_version(
                model,
                target_version=self.target_version,
            )
            return ir.passes.PassResult(model, True)

        # Reaching this point implies fallback is enabled and the onnxscript
        # converter does not support the target version. (A former
        # "fallback disabled" branch here was unreachable and has been removed.)
        logger.warning(
            "The model version conversion is not supported by the onnxscript version converter "
            "and fallback is enabled. The model will be converted using the onnx C API "
            "(target version: %d).",
            self.target_version,
        )
        return self._convert_with_c_api(model)

    def _convert_with_c_api(self, model: ir.Model) -> ir.passes.PassResult:
        """Convert ``model`` in place with the ONNX C API version converter.

        Returns ``PassResult(model, False)`` and logs a warning if the C API
        raises. Otherwise the original model's graph is replaced by the
        converted graph and ``PassResult(model, True)`` is returned.
        """

        def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:
            """Run the ONNX C API version converter on ``proto``."""
            return onnx.version_converter.convert_version(
                proto, target_version=self.target_version
            )

        try:
            converted_proto = _c_api_utils.call_onnx_api(
                func=_partial_convert_version, model=model
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Failed to convert the model to the target version %d using the ONNX C API. "
                "The model was not modified",
                self.target_version,
                exc_info=e,
            )
            return ir.passes.PassResult(model, False)

        converted_model = ir.from_proto(converted_proto)

        # Recover the initializers in the converted model.
        # NOTE(review): this assumes `_c_api_utils.call_onnx_api` exposes the
        # original initializers as extra graph inputs appended after the user
        # inputs (hence the slice below) — confirm against that helper.
        for graph_input in converted_model.graph.inputs:
            if graph_input.name in model.graph.initializers:
                graph_input.const_value = model.graph.initializers[
                    graph_input.name
                ].const_value
                converted_model.graph.register_initializer(graph_input)
        user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)]
        converted_model.graph.inputs.clear()
        converted_model.graph.inputs.extend(user_inputs)

        # Return the converted graph to the original model to keep the pass in-place
        model.graph = converted_model.graph
        return ir.passes.PassResult(model, True)
def convert_version(
    model: ir.Model | onnx.ModelProto, target_version: int, fallback: bool | None = None
) -> None:
    """Convert the model to the specified ONNX opset version, in place.

    Args:
        model: The model to convert, either an ``ir.Model`` or an
            ``onnx.ModelProto``. A ``ModelProto`` is updated in place: its
            graph, functions and opset imports are replaced with the
            converted ones.
        target_version: The target ONNX opset version.
        fallback: Whether to fallback to the onnx version converter if the
            target version is not supported. ``None`` (the default) is
            treated as False.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an
            ``onnx.ModelProto``.
    """
    if isinstance(model, onnx.ModelProto):
        model_proto = model
        model = ir.from_proto(model)
    else:
        model_proto = None

    if not isinstance(model, ir.Model):
        raise TypeError(
            f"Expected an ir.Model or onnx.ModelProto, got {type(model).__name__}"
        )

    ConvertVersionPass(target_version=target_version, fallback=bool(fallback))(model)

    if model_proto is not None:
        # Write the result back into the caller's proto. Serialize the whole
        # model because opset imports live on ModelProto, not GraphProto.
        # Copying only the graph would leave the proto advertising the old
        # opset version for a graph that now uses the new one.
        converted = ir.to_proto(model)
        model_proto.graph.CopyFrom(converted.graph)
        del model_proto.functions[:]
        model_proto.functions.extend(converted.functions)
        del model_proto.opset_import[:]
        model_proto.opset_import.extend(converted.opset_import)