Skip to content

Commit 4527f92

Browse files
committed
More cleanup
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent dda7977 commit 4527f92

File tree

3 files changed

+19
-40
lines changed

3 files changed

+19
-40
lines changed

onnxscript/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,7 +1313,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13131313
inputs = [o_loop_bound, o_loop_condition] + [
13141314
self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars
13151315
]
1316-
attrs = [self._make_onnx_attr("body", body.ir_function.graph)]
1316+
attrs = [self._make_onnx_attr("body", body.graph)]
13171317
info = self._source_of(loop_stmt)
13181318

13191319
def rename(x):
@@ -1378,7 +1378,7 @@ def _translate_block(
13781378
typeinfo = None
13791379
self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source)
13801380
graph = self._exit_scope()
1381-
return graph.ir_function.graph
1381+
return graph.graph
13821382

13831383
def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
13841384
"""Translate a nested function definition."""

onnxscript/irbuilder.py

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,63 +31,42 @@ def select_ir_version(version: int, domain: str = "") -> int:
3131
TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue
3232

3333

34-
class IRFunction:
34+
class IRFunction(ir.Function):
3535
"""Represents a function in the IR."""
3636

3737
def __init__(self, name: str, domain: str = "") -> None:
3838
graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name)
39-
self.ir_function = ir.Function(domain, name, graph=graph, attributes=[])
39+
super().__init__(domain, name, graph=graph, attributes=[])
4040
self.ordered_inputs_and_attrs: list[Union[ir.Value, ir.Attr]] = []
4141

4242
# a dictionary of nested function-definitions
4343
self.nested_functions: dict[str, IRFunction] = {}
4444
self.outer_scope_variables: dict[Any, Any] = {}
4545

46-
@property
47-
def outputs(self) -> Sequence[ir.Value]:
48-
return self.ir_function.outputs
49-
50-
@property
51-
def domain(self) -> str:
52-
"""Returns the domain of this function."""
53-
return self.ir_function.domain
54-
5546
@property
5647
def docstring(self) -> str:
5748
"""Returns the docstring of this function."""
58-
return self.ir_function.doc_string or ""
59-
60-
@property
61-
def name(self) -> str:
62-
"""Returns the name of this function."""
63-
return self.ir_function.name
49+
return self.doc_string or ""
6450

6551
@property
6652
def assigned_names(self) -> Sequence[str]:
6753
"""Returns the list of variables assigned to by this function."""
68-
return [v.name for n in self.ir_function for v in n.outputs]
69-
70-
@property
71-
def inputs(self) -> Sequence[ir.Value]:
72-
return self.ir_function.inputs
54+
return [v.name for n in self for v in n.outputs]
7355

7456
@property
7557
def attrs(self) -> Sequence[ir.Attr]:
7658
return [attr for attr in self.ordered_inputs_and_attrs if isinstance(attr, ir.Attr)]
7759

78-
def __str__(self):
79-
return str(self.ir_function)
80-
8160
def append_node(self, node: ir.Node) -> None:
82-
count = len(self.ir_function)
61+
count = len(self)
8362
node.name = f"n{count}"
84-
self.ir_function.append(node)
63+
self.append(node)
8564
domain = node.domain
8665
version = node.version
87-
if domain not in self.ir_function.opset_imports:
88-
self.ir_function.opset_imports[domain] = version
66+
if domain not in self.opset_imports:
67+
self.opset_imports[domain] = version
8968
else:
90-
existing_version = self.ir_function.opset_imports[domain]
69+
existing_version = self.opset_imports[domain]
9170
if existing_version != version:
9271
warnings.warn(
9372
f"Version conflict: domain: {domain!r}, "
@@ -98,14 +77,14 @@ def append_node(self, node: ir.Node) -> None:
9877

9978
def append_input(self, var: ir.Value) -> None:
10079
self.ordered_inputs_and_attrs.append(var)
101-
self.ir_function.inputs.append(var)
80+
self.inputs.append(var)
10281

10382
def append_output(self, var: ir.Value) -> None:
104-
self.ir_function.outputs.append(var)
83+
self.outputs.append(var)
10584

10685
def add_attr_parameter(self, attr: ir.Attr) -> None:
10786
self.ordered_inputs_and_attrs.append(attr)
108-
self.ir_function.attributes.add(attr)
87+
self.attributes.add(attr)
10988

11089
def add_nested_function(self, fun: IRFunction) -> None:
11190
self.nested_functions[fun.name] = fun
@@ -114,7 +93,7 @@ def get_called_functions(self) -> dict[str, onnx.FunctionProto]:
11493
called_functions: dict[str, values.OnnxFunction] = {}
11594

11695
def visit(function_ir: IRFunction):
117-
for node in ir.traversal.RecursiveGraphIterator(function_ir.ir_function.graph):
96+
for node in ir.traversal.RecursiveGraphIterator(function_ir.graph):
11897
callee = node.meta.get("callee", None)
11998
if isinstance(callee, values.OnnxFunction):
12099
add(callee)
@@ -139,11 +118,11 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto:
139118
an instance of :class:`onnx.GraphProto`
140119
"""
141120
del use_default_type # currently not used
142-
return ir.to_proto(self.ir_function.graph)
121+
return ir.to_proto(self.graph)
143122

144123
def to_function_proto(self) -> onnx.FunctionProto:
145124
"""Converts this instance into a `onnx.FunctionProto`."""
146-
return ir.to_proto(self.ir_function)
125+
return ir.to_proto(self)
147126

148127

149128
# IRBuilder: abstracts out details of the IR in the python-to-IR converter
@@ -180,7 +159,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I
180159
return function
181160

182161
def add_docstring(self, fn: IRFunction, docstring: str):
183-
fn.ir_function.doc_string = docstring
162+
fn.doc_string = docstring
184163

185164
def add_stmt(
186165
self,

onnxscript/values.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def to_proto(f):
692692

693693
functions = [to_proto(f) for f in functions]
694694

695-
opsets = self.function_ir.ir_function.opset_imports.copy()
695+
opsets = self.function_ir.opset_imports.copy()
696696

697697
for proto in functions:
698698
if proto.domain not in opsets:

0 commit comments

Comments
 (0)