Skip to content

Commit fae6609

Browse files
committed
IR builder cleanup
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 7b8cd9b commit fae6609

File tree

1 file changed

+30
-15
lines changed

1 file changed

+30
-15
lines changed

onnxscript/irbuilder.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -
7474
self.name = varname
7575
self.info = sourceinfo
7676
self.typeinfo = typeinfo
77-
if typeinfo is None:
77+
if typeinfo is None or not hasattr(typeinfo, "to_type_proto"):
7878
self.value = ir.Value(name=varname)
7979
else:
8080
type_and_shape = ir.from_proto(typeinfo.to_type_proto())
@@ -135,6 +135,7 @@ class IRAttributeParameter:
135135

136136
name: str
137137
type: onnx.AttributeProto.AttributeType
138+
attr: ir.Attr
138139
default_value: str | int | float | None = None
139140

140141
# TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.
@@ -193,9 +194,8 @@ def debug_print(self):
193194
if logger.isEnabledFor(logging.DEBUG):
194195
logger.debug("%s: %s", type(self), self)
195196

196-
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
197+
def to_node_proto(self) -> onnx.NodeProto:
197198
n = ir.to_proto(self.node)
198-
n.name = node_name
199199
return n
200200

201201
@property
@@ -208,8 +208,8 @@ class IRFunction:
208208
"""Represents a function in the IR."""
209209

210210
def __init__(self, name: str, domain: str = "") -> None:
211-
self.ir_graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name)
212-
self.domain = domain
211+
graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name)
212+
self.ir_function = ir.Function(domain, name, graph=graph, attributes=[])
213213
self.outputs: list[IRVar] = []
214214
self.stmts: list[IRStmt] = []
215215
self.called_functions: dict[str, onnx.FunctionProto] = {}
@@ -218,15 +218,20 @@ def __init__(self, name: str, domain: str = "") -> None:
218218
self.outer_scope_variables: dict[Any, Any] = {}
219219
self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = []
220220

221+
@property
222+
def domain(self) -> str:
223+
"""Returns the domain of this function."""
224+
return self.ir_function.domain
225+
221226
@property
222227
def docstring(self) -> str:
223228
"""Returns the docstring of this function."""
224-
return self.ir_graph.doc_string or ""
229+
return self.ir_function.doc_string or ""
225230

226231
@property
227232
def name(self) -> str:
228233
"""Returns the name of this function."""
229-
return self.ir_graph.name
234+
return self.ir_function.name
230235

231236
@property
232237
def assigned_names(self) -> Sequence[str]:
@@ -253,16 +258,23 @@ def __str__(self):
253258
return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"
254259

255260
def append_stmt(self, stmt: IRStmt) -> None:
261+
count = len(self.stmts)
262+
node_name = f"n{count}"
263+
stmt.node.name = node_name
256264
self.stmts.append(stmt)
265+
self.ir_function.append(stmt.node)
257266

258-
def append_input(self, name: IRVar) -> None:
259-
self.ordered_inputs_and_attrs.append(name)
267+
def append_input(self, var: IRVar) -> None:
268+
self.ordered_inputs_and_attrs.append(var)
269+
self.ir_function.inputs.append(var.value)
260270

261-
def append_output(self, name: IRVar) -> None:
262-
self.outputs.append(name)
271+
def append_output(self, var: IRVar) -> None:
272+
self.outputs.append(var)
273+
self.ir_function.outputs.append(var.value)
263274

264275
def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
265276
self.ordered_inputs_and_attrs.append(attr)
277+
self.ir_function.attributes.add(attr.attr)
266278

267279
def debug_print(self):
268280
if logger.isEnabledFor(logging.DEBUG):
@@ -407,7 +419,7 @@ def _to_graph_and_functions(
407419
called_functions.update(s.functions)
408420
called_functions.update(self.called_functions)
409421
graph = helper.make_graph(
410-
[s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)],
422+
[s.to_node_proto() for s in self.stmts],
411423
self.name,
412424
[x.to_value_info(use_default_type) for x in self.inputs],
413425
[y.to_value_info(use_default_type) for y in self.outputs],
@@ -450,7 +462,7 @@ def to_function_proto(self) -> onnx.FunctionProto:
450462
doesn't support it.
451463
"""
452464
opsets = self.get_opset_import()
453-
nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)]
465+
nodes = [s.to_node_proto() for s in self.stmts]
454466
for n in nodes:
455467
if n.domain not in opsets:
456468
opsets[n.domain] = 1 # TODO: how to get n.version?
@@ -494,7 +506,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I
494506
return function
495507

496508
def add_docstring(self, fn: IRFunction, docstring: str):
497-
fn.ir_graph.doc_string = docstring
509+
fn.ir_function.doc_string = docstring
498510

499511
def add_stmt(
500512
self,
@@ -533,7 +545,10 @@ def add_attr_parameter(
533545
attribute_type: onnx.AttributeProto.AttributeType,
534546
default_value: int | float | str | None,
535547
) -> None:
536-
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))
548+
attr = ir.Attr(varname, ir.AttributeType(attribute_type), None, None)
549+
fn.add_attr_parameter(
550+
IRAttributeParameter(varname, attribute_type, attr, default_value)
551+
)
537552

538553
def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
539554
var = IRVar(varname, typeinfo, sourceinfo)

0 commit comments

Comments
 (0)