|
9 | 9 |
|
10 | 10 | import onnx |
11 | 11 | import onnx_ir as ir |
12 | | -from onnx import helper |
13 | | -from onnx.defs import onnx_opset_version |
14 | 12 |
|
15 | | -import onnxscript |
16 | 13 | import onnxscript.type_annotation |
17 | 14 | from onnxscript import values |
18 | | -from onnxscript.onnx_types import ONNXType |
19 | 15 | from onnxscript.sourceinfo import SourceInfo |
20 | 16 |
|
21 | 17 | logger = logging.getLogger("onnxscript") |
22 | 18 |
|
23 | 19 |
|
24 | | -def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str): |
25 | | - """Formats a sequence of objects into a string.""" |
26 | | - return prefix + sep.join([formatter(x) for x in seq]) + suffix |
27 | | - |
28 | | - |
29 | 20 | def select_ir_version(version: int, domain: str = "") -> int: |
30 | 21 | """Selects a suitable ONNX ir_version for a given opset version.""" |
31 | 22 | if domain == "": |
32 | 23 | domain = "ai.onnx" |
33 | | - if (domain, version) not in helper.OP_SET_ID_VERSION_MAP: |
34 | | - return max(v for k, v in helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx") |
35 | | - return helper.OP_SET_ID_VERSION_MAP[domain, version] |
| 24 | + if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP: |
| 25 | + return max( |
| 26 | + v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx" |
| 27 | + ) |
| 28 | + return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version] |
36 | 29 |
|
37 | 30 |
|
38 | 31 | TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue |
@@ -76,9 +69,7 @@ def assigned_names(self) -> Sequence[str]: |
76 | 69 |
|
77 | 70 | @property |
78 | 71 | def inputs(self) -> Sequence[ir.Value]: |
79 | | - return ( |
80 | | - self.ir_function.inputs |
81 | | - ) # [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)] |
| 72 | + return self.ir_function.inputs |
82 | 73 |
|
83 | 74 | @property |
84 | 75 | def attrs(self) -> Sequence[ir.Attr]: |
@@ -116,106 +107,9 @@ def add_attr_parameter(self, attr: ir.Attr) -> None: |
116 | 107 | self.ordered_inputs_and_attrs.append(attr) |
117 | 108 | self.ir_function.attributes.add(attr) |
118 | 109 |
|
119 | | - def debug_print(self): |
120 | | - if logger.isEnabledFor(logging.DEBUG): |
121 | | - logger.debug(str(self.ir_function)) |
122 | | - |
123 | 110 | def add_nested_function(self, fun: IRFunction) -> None: |
124 | 111 | self.nested_functions[fun.name] = fun |
125 | 112 |
|
126 | | - def to_model_proto( |
127 | | - self, |
128 | | - functions=None, |
129 | | - io_types: Optional[ONNXType] = None, |
130 | | - input_types: Optional[Sequence[ONNXType]] = None, |
131 | | - output_types: Optional[Sequence[ONNXType]] = None, |
132 | | - value_infos: dict[str, ONNXType] | None = None, |
133 | | - opset_version: int | None = None, |
134 | | - **kwargs, |
135 | | - ) -> onnx.ModelProto: |
136 | | - """Converts this instance into a `onnx.ModelProto`. |
137 | | -
|
138 | | - Args: |
139 | | - functions: A list of functions to include in the model. |
140 | | - By default, all functions called at least once are included. |
141 | | - io_types: When specified, all the inputs/outputs of the model |
142 | | - are set to be of this type. |
143 | | - input_types: When specified, all the inputs of the model |
144 | | - are set to be of the corresponding type in this list. |
145 | | - output_types: When specified, all the outputs of the model |
146 | | - are set to be of the corresponding type in this list. |
147 | | - value_infos: A dictionary mapping intermediate variable names to ONNX types. |
148 | | - Used to set value_info for intermediate variables. |
149 | | - opset_version: The standard opset version to use for the model if it |
150 | | - cannot be inferred. Otherwise defaults to the current opset version. |
151 | | - kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. |
152 | | -
|
153 | | - Returns: |
154 | | - An instance of :class:`onnx.ModelProto`. |
155 | | - """ |
156 | | - value_infos = ( |
157 | | - [ |
158 | | - onnx.helper.make_value_info(name, type.to_type_proto()) |
159 | | - for name, type in value_infos.items() |
160 | | - ] |
161 | | - if value_infos |
162 | | - else None |
163 | | - ) |
164 | | - sub_functions = self.get_called_functions() |
165 | | - graph = self.to_graph_proto(use_default_type=False) |
166 | | - if value_infos: |
167 | | - graph.value_info.extend(value_infos) |
168 | | - if io_types is not None: |
169 | | - for input in graph.input: |
170 | | - if not input.HasField("type"): |
171 | | - input.type.CopyFrom(io_types.to_type_proto()) |
172 | | - for output in graph.output: |
173 | | - if not output.HasField("type"): |
174 | | - output.type.CopyFrom(io_types.to_type_proto()) |
175 | | - if input_types is not None: |
176 | | - for input, type in zip(graph.input, input_types): |
177 | | - input.type.CopyFrom(type.to_type_proto()) |
178 | | - if output_types is not None: |
179 | | - for output, type in zip(graph.output, output_types): |
180 | | - output.type.CopyFrom(type.to_type_proto()) |
181 | | - if functions is None: |
182 | | - functions = sub_functions.values() |
183 | | - else: |
184 | | - |
185 | | - def to_proto(f): |
186 | | - if isinstance(f, onnx.FunctionProto): |
187 | | - return f |
188 | | - if isinstance(f, onnxscript.OnnxFunction): |
189 | | - return f.to_function_proto() |
190 | | - raise TypeError("Expected a value of type FunctionProto of OnnxFunction") |
191 | | - |
192 | | - functions = [to_proto(f) for f in functions] |
193 | | - |
194 | | - opsets = self.ir_function.opset_imports.copy() |
195 | | - |
196 | | - for proto in functions: |
197 | | - if proto.domain not in opsets: |
198 | | - opsets[proto.domain] = 1 |
199 | | - # TODO(rama): Handle conflicts with appropriate error/warning message. |
200 | | - for opset in proto.opset_import: |
201 | | - if opset.domain not in opsets: |
202 | | - opsets[opset.domain] = opset.version |
203 | | - |
204 | | - if "" not in opsets: |
205 | | - # No operator is using the standard opset. |
206 | | - # Use the specified version if provided or the default value. |
207 | | - opsets[""] = opset_version if opset_version is not None else onnx_opset_version() |
208 | | - |
209 | | - if "ir_version" not in kwargs: |
210 | | - kwargs["ir_version"] = select_ir_version(opsets[""]) |
211 | | - opset_imports = [ |
212 | | - onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() |
213 | | - ] |
214 | | - |
215 | | - return helper.make_model( |
216 | | - graph, opset_imports=opset_imports, functions=functions, **kwargs |
217 | | - ) |
218 | | - |
219 | 113 | def get_called_functions(self) -> dict[str, onnx.FunctionProto]: |
220 | 114 | called_functions: dict[str, values.OnnxFunction] = {} |
221 | 115 |
|
@@ -297,7 +191,7 @@ def add_stmt( |
297 | 191 | attrs: Sequence[ir.Attr], |
298 | 192 | ) -> Sequence[ir.Value]: |
299 | 193 | output_values = [ir.Value(name=o) for o in results] |
300 | | - attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs] |
| 194 | + attributes = attrs |
301 | 195 | node = ir.Node( |
302 | 196 | domain=callee.opset.domain, |
303 | 197 | version=callee.opset.version, |
|
0 commit comments