Skip to content

Commit 7b8cd9b

Browse files
committed
IR cleanup
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent ddce3b3 commit 7b8cd9b

File tree

1 file changed

+21
-61
lines changed

1 file changed

+21
-61
lines changed

onnxscript/irbuilder.py

Lines changed: 21 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -
7474
self.name = varname
7575
self.info = sourceinfo
7676
self.typeinfo = typeinfo
77+
if typeinfo is None:
78+
self.value = ir.Value(name=varname)
79+
else:
80+
type_and_shape = ir.from_proto(typeinfo.to_type_proto())
81+
self.value = ir.Value(
82+
name=varname, type=type_and_shape.type, shape=type_and_shape.shape
83+
)
7784

7885
def __str__(self):
7986
return self.name
@@ -109,33 +116,7 @@ def _opt_var_to_str(x):
109116
return "" if x is None else str(x)
110117

111118

112-
class IRAttributeValue:
113-
"""An attribute value (representing an actual parameter).
114-
115-
Attributes:
116-
name: The name of the attribute.
117-
type: The type of the attribute.
118-
attr_proto: The attribute proto.
119-
"""
120-
121-
def __init__(self, attrproto: onnx.AttributeProto) -> None:
122-
if not isinstance(attrproto, onnx.AttributeProto):
123-
raise TypeError(f"Expected onnx.AttributeProto not {type(attrproto)!r}.")
124-
self.attr_proto = attrproto
125-
126-
def __str__(self):
127-
if self.attr_proto.HasField("ref_attr_name"):
128-
return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}"
129-
# self.name + " = " + self.value
130-
return helper.printable_attribute(self.attr_proto)
131-
132-
@property
133-
def name(self) -> str:
134-
return self.attr_proto.name
135-
136-
@property
137-
def type(self) -> onnx.AttributeProto.AttributeType:
138-
return self.attr_proto.type
119+
IRAttributeValue = ir.Attr
139120

140121

141122
@dataclasses.dataclass(frozen=True)
@@ -202,35 +183,19 @@ def args(self) -> Sequence[Optional[str]]:
202183
return [x.name if x is not None else None for x in self.node.inputs]
203184

204185
@property
205-
def attrs(self) -> Sequence[IRAttributeValue]:
206-
return [IRAttributeValue(ir.to_proto(a)) for a in self.node.attributes.values()]
186+
def attrs(self) -> Sequence[ir.Attr]:
187+
return list(self.node.attributes.values())
207188

208189
def __str__(self):
209-
lhs = ", ".join(self.output_names)
210-
attrs = ""
211-
if self.attrs:
212-
attrs = _format(self.attrs, "<", ", ", ">")
213-
214-
args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
215-
domain = self.callee.opset.domain
216-
opname = self.callee.name
217-
callee = f"{domain}.{opname}" if (domain != "") else opname
218-
return f"{lhs} = {callee} {attrs}{args}"
190+
return str(self.node)
219191

220192
def debug_print(self):
221193
if logger.isEnabledFor(logging.DEBUG):
222194
logger.debug("%s: %s", type(self), self)
223195

224196
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
225-
n = helper.make_node(
226-
self.callee.name,
227-
[_opt_var_to_str(x) for x in self.args],
228-
self.output_names,
229-
domain=self.callee.opset.domain,
230-
name=node_name,
231-
)
232-
for a in self.attrs:
233-
n.attribute.append(a.attr_proto)
197+
n = ir.to_proto(self.node)
198+
n.name = node_name
234199
return n
235200

236201
@property
@@ -537,11 +502,11 @@ def add_stmt(
537502
results: Sequence[str],
538503
callee: values.Op,
539504
inputs: Sequence[Optional[ir.Value]],
540-
attrs: Sequence[IRAttributeValue],
505+
attrs: Sequence[ir.Attr],
541506
sub_functions=None,
542507
) -> Sequence[ir.Value]:
543508
output_values = [ir.Value(name=o) for o in results]
544-
attributes = [ir.from_proto(a.attr_proto) for a in attrs]
509+
attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs]
545510
node = ir.Node(
546511
domain=callee.opset.domain,
547512
version=callee.opset.version,
@@ -574,14 +539,9 @@ def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None
574539
var = IRVar(varname, typeinfo, sourceinfo)
575540
fn.append_output(var)
576541

577-
def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue:
578-
return IRAttributeValue(attrproto)
579-
580-
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
581-
proto = onnx.AttributeProto()
582-
proto.name = attrname
583-
proto.ref_attr_name = refname
584-
attr_type = ta.pytype_to_attrtype(pytype)
585-
assert attr_type is not None
586-
proto.type = attr_type
587-
return IRAttributeValue(proto)
542+
def make_attr(self, attrproto: onnx.AttributeProto) -> ir.Attr:
543+
return ir.from_proto(attrproto)
544+
545+
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.Attr:
546+
attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype))
547+
return ir.Attr(attrname, attr_type, None, refname)

0 commit comments

Comments
 (0)