Skip to content

Commit dda7977

Browse files
committed
Move to_model_proto
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent f7eb6d1 commit dda7977

File tree

5 files changed

+119
-118
lines changed

5 files changed

+119
-118
lines changed

onnxscript/converter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1475,7 +1475,6 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
14751475
self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
14761476
self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals)
14771477
fn_ir = self._translate_function_def_common(stmt)
1478-
fn_ir.debug_print()
14791478
self.this_module.add_function_def(fn_ir)
14801479
self._analyzer = None
14811480
return fn_ir

onnxscript/converter_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_unary_op(self):
235235
def test_subfunction_check_model(self):
236236
from tests.models import subfunction
237237

238-
model = subfunction.MyElu.function_ir.to_model_proto(producer_name="p2o")
238+
model = subfunction.MyElu.to_model_proto(producer_name="p2o")
239239
model = onnx.shape_inference.infer_shapes(model)
240240
onnx.checker.check_model(model)
241241

onnxscript/irbuilder.py

Lines changed: 7 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,23 @@
99

1010
import onnx
1111
import onnx_ir as ir
12-
from onnx import helper
13-
from onnx.defs import onnx_opset_version
1412

15-
import onnxscript
1613
import onnxscript.type_annotation
1714
from onnxscript import values
18-
from onnxscript.onnx_types import ONNXType
1915
from onnxscript.sourceinfo import SourceInfo
2016

2117
logger = logging.getLogger("onnxscript")
2218

2319

24-
def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str):
25-
"""Formats a sequence of objects into a string."""
26-
return prefix + sep.join([formatter(x) for x in seq]) + suffix
27-
28-
2920
def select_ir_version(version: int, domain: str = "") -> int:
3021
"""Selects a suitable ONNX ir_version for a given opset version."""
3122
if domain == "":
3223
domain = "ai.onnx"
33-
if (domain, version) not in helper.OP_SET_ID_VERSION_MAP:
34-
return max(v for k, v in helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx")
35-
return helper.OP_SET_ID_VERSION_MAP[domain, version]
24+
if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP:
25+
return max(
26+
v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx"
27+
)
28+
return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version]
3629

3730

3831
TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue
@@ -76,9 +69,7 @@ def assigned_names(self) -> Sequence[str]:
7669

7770
@property
7871
def inputs(self) -> Sequence[ir.Value]:
79-
return (
80-
self.ir_function.inputs
81-
) # [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)]
72+
return self.ir_function.inputs
8273

8374
@property
8475
def attrs(self) -> Sequence[ir.Attr]:
@@ -116,106 +107,9 @@ def add_attr_parameter(self, attr: ir.Attr) -> None:
116107
self.ordered_inputs_and_attrs.append(attr)
117108
self.ir_function.attributes.add(attr)
118109

119-
def debug_print(self):
120-
if logger.isEnabledFor(logging.DEBUG):
121-
logger.debug(str(self.ir_function))
122-
123110
def add_nested_function(self, fun: IRFunction) -> None:
124111
self.nested_functions[fun.name] = fun
125112

126-
def to_model_proto(
127-
self,
128-
functions=None,
129-
io_types: Optional[ONNXType] = None,
130-
input_types: Optional[Sequence[ONNXType]] = None,
131-
output_types: Optional[Sequence[ONNXType]] = None,
132-
value_infos: dict[str, ONNXType] | None = None,
133-
opset_version: int | None = None,
134-
**kwargs,
135-
) -> onnx.ModelProto:
136-
"""Converts this instance into a `onnx.ModelProto`.
137-
138-
Args:
139-
functions: A list of functions to include in the model.
140-
By default, all functions called at least once are included.
141-
io_types: When specified, all the inputs/outputs of the model
142-
are set to be of this type.
143-
input_types: When specified, all the inputs of the model
144-
are set to be of the corresponding type in this list.
145-
output_types: When specified, all the outputs of the model
146-
are set to be of the corresponding type in this list.
147-
value_infos: A dictionary mapping intermediate variable names to ONNX types.
148-
Used to set value_info for intermediate variables.
149-
opset_version: The standard opset version to use for the model if it
150-
cannot be inferred. Otherwise defaults to the current opset version.
151-
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
152-
153-
Returns:
154-
An instance of :class:`onnx.ModelProto`.
155-
"""
156-
value_infos = (
157-
[
158-
onnx.helper.make_value_info(name, type.to_type_proto())
159-
for name, type in value_infos.items()
160-
]
161-
if value_infos
162-
else None
163-
)
164-
sub_functions = self.get_called_functions()
165-
graph = self.to_graph_proto(use_default_type=False)
166-
if value_infos:
167-
graph.value_info.extend(value_infos)
168-
if io_types is not None:
169-
for input in graph.input:
170-
if not input.HasField("type"):
171-
input.type.CopyFrom(io_types.to_type_proto())
172-
for output in graph.output:
173-
if not output.HasField("type"):
174-
output.type.CopyFrom(io_types.to_type_proto())
175-
if input_types is not None:
176-
for input, type in zip(graph.input, input_types):
177-
input.type.CopyFrom(type.to_type_proto())
178-
if output_types is not None:
179-
for output, type in zip(graph.output, output_types):
180-
output.type.CopyFrom(type.to_type_proto())
181-
if functions is None:
182-
functions = sub_functions.values()
183-
else:
184-
185-
def to_proto(f):
186-
if isinstance(f, onnx.FunctionProto):
187-
return f
188-
if isinstance(f, onnxscript.OnnxFunction):
189-
return f.to_function_proto()
190-
raise TypeError("Expected a value of type FunctionProto of OnnxFunction")
191-
192-
functions = [to_proto(f) for f in functions]
193-
194-
opsets = self.ir_function.opset_imports.copy()
195-
196-
for proto in functions:
197-
if proto.domain not in opsets:
198-
opsets[proto.domain] = 1
199-
# TODO(rama): Handle conflicts with appropriate error/warning message.
200-
for opset in proto.opset_import:
201-
if opset.domain not in opsets:
202-
opsets[opset.domain] = opset.version
203-
204-
if "" not in opsets:
205-
# No operator is using the standard opset.
206-
# Use the specified version if provided or the default value.
207-
opsets[""] = opset_version if opset_version is not None else onnx_opset_version()
208-
209-
if "ir_version" not in kwargs:
210-
kwargs["ir_version"] = select_ir_version(opsets[""])
211-
opset_imports = [
212-
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
213-
]
214-
215-
return helper.make_model(
216-
graph, opset_imports=opset_imports, functions=functions, **kwargs
217-
)
218-
219113
def get_called_functions(self) -> dict[str, onnx.FunctionProto]:
220114
called_functions: dict[str, values.OnnxFunction] = {}
221115

@@ -297,7 +191,7 @@ def add_stmt(
297191
attrs: Sequence[ir.Attr],
298192
) -> Sequence[ir.Value]:
299193
output_values = [ir.Value(name=o) for o in results]
300-
attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs]
194+
attributes = attrs
301195
node = ir.Node(
302196
domain=callee.opset.domain,
303197
version=callee.opset.version,

onnxscript/values.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,23 @@
3030
from onnxscript import irbuilder, sourceinfo, type_annotation
3131
from onnxscript._internal import ast_utils, deprecation
3232
from onnxscript.ir import _schemas
33+
from onnxscript.onnx_types import ONNXType
3334

3435
_R = TypeVar("_R")
3536
_P = ParamSpec("_P")
3637

3738

39+
def select_ir_version(version: int, domain: str = "") -> int:
40+
"""Selects a suitable ONNX ir_version for a given opset version."""
41+
if domain == "":
42+
domain = "ai.onnx"
43+
if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP:
44+
return max(
45+
v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx"
46+
)
47+
return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version]
48+
49+
3850
_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
3951
onnx.defs.OpSchema.AttrType.FLOAT: float,
4052
onnx.defs.OpSchema.AttrType.INT: int,
@@ -609,7 +621,103 @@ def to_model_proto(self, **kwargs):
609621

610622
# Merge kwargs specified in script-decorator with those specified in this call.
611623
merged_kw_args = {**self.kwargs, **kwargs}
612-
return self.function_ir.to_model_proto(**merged_kw_args)
624+
return self._to_model_proto(**merged_kw_args)
625+
626+
def _to_model_proto(
627+
self,
628+
functions=None,
629+
io_types: Optional[ONNXType] = None,
630+
input_types: Optional[Sequence[ONNXType]] = None,
631+
output_types: Optional[Sequence[ONNXType]] = None,
632+
value_infos: dict[str, ONNXType] | None = None,
633+
opset_version: int | None = None,
634+
**kwargs,
635+
) -> onnx.ModelProto:
636+
"""Converts this instance into a `onnx.ModelProto`.
637+
638+
Args:
639+
functions: A list of functions to include in the model.
640+
By default, all functions called at least once are included.
641+
io_types: When specified, all the inputs/outputs of the model
642+
are set to be of this type.
643+
input_types: When specified, all the inputs of the model
644+
are set to be of the corresponding type in this list.
645+
output_types: When specified, all the outputs of the model
646+
are set to be of the corresponding type in this list.
647+
value_infos: A dictionary mapping intermediate variable names to ONNX types.
648+
Used to set value_info for intermediate variables.
649+
opset_version: The standard opset version to use for the model if it
650+
cannot be inferred. Otherwise defaults to the current opset version.
651+
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
652+
653+
Returns:
654+
An instance of :class:`onnx.ModelProto`.
655+
"""
656+
value_infos = (
657+
[
658+
onnx.helper.make_value_info(name, type.to_type_proto())
659+
for name, type in value_infos.items()
660+
]
661+
if value_infos
662+
else None
663+
)
664+
665+
graph = self.function_ir.to_graph_proto(use_default_type=False)
666+
if value_infos:
667+
graph.value_info.extend(value_infos)
668+
if io_types is not None:
669+
for input in graph.input:
670+
if not input.HasField("type"):
671+
input.type.CopyFrom(io_types.to_type_proto())
672+
for output in graph.output:
673+
if not output.HasField("type"):
674+
output.type.CopyFrom(io_types.to_type_proto())
675+
if input_types is not None:
676+
for input, type in zip(graph.input, input_types):
677+
input.type.CopyFrom(type.to_type_proto())
678+
if output_types is not None:
679+
for output, type in zip(graph.output, output_types):
680+
output.type.CopyFrom(type.to_type_proto())
681+
if functions is None:
682+
sub_functions = self.function_ir.get_called_functions()
683+
functions = sub_functions.values()
684+
else:
685+
686+
def to_proto(f):
687+
if isinstance(f, onnx.FunctionProto):
688+
return f
689+
if isinstance(f, OnnxFunction):
690+
return f.to_function_proto()
691+
raise TypeError("Expected a value of type FunctionProto of OnnxFunction")
692+
693+
functions = [to_proto(f) for f in functions]
694+
695+
opsets = self.function_ir.ir_function.opset_imports.copy()
696+
697+
for proto in functions:
698+
if proto.domain not in opsets:
699+
opsets[proto.domain] = 1
700+
# TODO(rama): Handle conflicts with appropriate error/warning message.
701+
for opset in proto.opset_import:
702+
if opset.domain not in opsets:
703+
opsets[opset.domain] = opset.version
704+
705+
if "" not in opsets:
706+
# No operator is using the standard opset.
707+
# Use the specified version if provided or the default value.
708+
opsets[""] = (
709+
opset_version if opset_version is not None else onnx.defs.onnx_opset_version()
710+
)
711+
712+
if "ir_version" not in kwargs:
713+
kwargs["ir_version"] = select_ir_version(opsets[""])
714+
opset_imports = [
715+
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
716+
]
717+
718+
return onnx.helper.make_model(
719+
graph, opset_imports=opset_imports, functions=functions, **kwargs
720+
)
613721

614722

615723
class TracedOnnxFunction(Op):

tests/common/onnx_script_test_case.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _create_model_from_param(
144144
# there is not way from the onnx test case's model and feed to get TypeProto
145145
# in order to build a model.
146146
# we have to resolve the TypeProto from script function.
147-
local_function_model_proto = param.function.function_ir.to_model_proto(
147+
local_function_model_proto = param.function.to_model_proto(
148148
ir_version=ir_version
149149
)
150150
input_value_infos = []
@@ -202,7 +202,7 @@ def run_converter_test(
202202
param, onnx_case_model, ir_version=ir_version
203203
)
204204
else:
205-
model = param.function.function_ir.to_model_proto(
205+
model = param.function.to_model_proto(
206206
producer_name="call_clip", ir_version=ir_version
207207
)
208208
try:

0 commit comments

Comments
 (0)