@@ -74,6 +74,13 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -
7474 self .name = varname
7575 self .info = sourceinfo
7676 self .typeinfo = typeinfo
77+ if typeinfo is None :
78+ self .value = ir .Value (name = varname )
79+ else :
80+ type_and_shape = ir .from_proto (typeinfo .to_type_proto ())
81+ self .value = ir .Value (
82+ name = varname , type = type_and_shape .type , shape = type_and_shape .shape
83+ )
7784
7885 def __str__ (self ):
7986 return self .name
@@ -109,33 +116,7 @@ def _opt_var_to_str(x):
109116 return "" if x is None else str (x )
110117
111118
112- class IRAttributeValue :
113- """An attribute value (representing an actual parameter).
114-
115- Attributes:
116- name: The name of the attribute.
117- type: The type of the attribute.
118- attr_proto: The attribute proto.
119- """
120-
121- def __init__ (self , attrproto : onnx .AttributeProto ) -> None :
122- if not isinstance (attrproto , onnx .AttributeProto ):
123- raise TypeError (f"Expected onnx.AttributeProto not { type (attrproto )!r} ." )
124- self .attr_proto = attrproto
125-
126- def __str__ (self ):
127- if self .attr_proto .HasField ("ref_attr_name" ):
128- return f"{ self .attr_proto .name } = @{ self .attr_proto .ref_attr_name } "
129- # self.name + " = " + self.value
130- return helper .printable_attribute (self .attr_proto )
131-
132- @property
133- def name (self ) -> str :
134- return self .attr_proto .name
135-
136- @property
137- def type (self ) -> onnx .AttributeProto .AttributeType :
138- return self .attr_proto .type
119+ IRAttributeValue = ir .Attr
139120
140121
141122@dataclasses .dataclass (frozen = True )
@@ -202,35 +183,19 @@ def args(self) -> Sequence[Optional[str]]:
202183 return [x .name if x is not None else None for x in self .node .inputs ]
203184
204185 @property
205- def attrs (self ) -> Sequence [IRAttributeValue ]:
206- return [ IRAttributeValue ( ir . to_proto ( a )) for a in self .node .attributes .values ()]
186+ def attrs (self ) -> Sequence [ir . Attr ]:
187+ return list ( self .node .attributes .values ())
207188
208189 def __str__ (self ):
209- lhs = ", " .join (self .output_names )
210- attrs = ""
211- if self .attrs :
212- attrs = _format (self .attrs , "<" , ", " , ">" )
213-
214- args = _format (self .args , "(" , ", " , ")" , _opt_var_to_str )
215- domain = self .callee .opset .domain
216- opname = self .callee .name
217- callee = f"{ domain } .{ opname } " if (domain != "" ) else opname
218- return f"{ lhs } = { callee } { attrs } { args } "
190+ return str (self .node )
219191
220192 def debug_print (self ):
221193 if logger .isEnabledFor (logging .DEBUG ):
222194 logger .debug ("%s: %s" , type (self ), self )
223195
224196 def to_node_proto (self , node_name : str ) -> onnx .NodeProto :
225- n = helper .make_node (
226- self .callee .name ,
227- [_opt_var_to_str (x ) for x in self .args ],
228- self .output_names ,
229- domain = self .callee .opset .domain ,
230- name = node_name ,
231- )
232- for a in self .attrs :
233- n .attribute .append (a .attr_proto )
197+ n = ir .to_proto (self .node )
198+ n .name = node_name
234199 return n
235200
236201 @property
@@ -537,11 +502,11 @@ def add_stmt(
537502 results : Sequence [str ],
538503 callee : values .Op ,
539504 inputs : Sequence [Optional [ir .Value ]],
540- attrs : Sequence [IRAttributeValue ],
505+ attrs : Sequence [ir . Attr ],
541506 sub_functions = None ,
542507 ) -> Sequence [ir .Value ]:
543508 output_values = [ir .Value (name = o ) for o in results ]
544- attributes = [ir .from_proto (a .attr_proto ) for a in attrs ]
509+ attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs]
545510 node = ir .Node (
546511 domain = callee .opset .domain ,
547512 version = callee .opset .version ,
@@ -574,14 +539,9 @@ def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None
574539 var = IRVar (varname , typeinfo , sourceinfo )
575540 fn .append_output (var )
576541
577- def make_attr (self , attrproto : onnx .AttributeProto ) -> IRAttributeValue :
578- return IRAttributeValue (attrproto )
579-
580- def make_attr_ref (self , attrname : str , refname : str , pytype : type ) -> IRAttributeValue :
581- proto = onnx .AttributeProto ()
582- proto .name = attrname
583- proto .ref_attr_name = refname
584- attr_type = ta .pytype_to_attrtype (pytype )
585- assert attr_type is not None
586- proto .type = attr_type
587- return IRAttributeValue (proto )
542+ def make_attr (self , attrproto : onnx .AttributeProto ) -> ir .Attr :
543+ return ir .from_proto (attrproto )
544+
545+ def make_attr_ref (self , attrname : str , refname : str , pytype : type ) -> ir .Attr :
546+ attr_type = ir .AttributeType (ta .pytype_to_attrtype (pytype ))
547+ return ir .Attr (attrname , attr_type , None , refname )
0 commit comments