@@ -31,63 +31,42 @@ def select_ir_version(version: int, domain: str = "") -> int:
3131TypeAnnotationValue = onnxscript .type_annotation .TypeAnnotationValue
3232
3333
34- class IRFunction :
34+ class IRFunction ( ir . Function ) :
3535 """Represents a function in the IR."""
3636
3737 def __init__ (self , name : str , domain : str = "" ) -> None :
3838 graph = ir .Graph (inputs = [], outputs = [], nodes = [], name = name )
39- self . ir_function = ir . Function (domain , name , graph = graph , attributes = [])
39+ super (). __init__ (domain , name , graph = graph , attributes = [])
4040 self .ordered_inputs_and_attrs : list [Union [ir .Value , ir .Attr ]] = []
4141
4242 # a dictionary of nested function-definitions
4343 self .nested_functions : dict [str , IRFunction ] = {}
4444 self .outer_scope_variables : dict [Any , Any ] = {}
4545
46- @property
47- def outputs (self ) -> Sequence [ir .Value ]:
48- return self .ir_function .outputs
49-
50- @property
51- def domain (self ) -> str :
52- """Returns the domain of this function."""
53- return self .ir_function .domain
54-
5546 @property
5647 def docstring (self ) -> str :
5748 """Returns the docstring of this function."""
58- return self .ir_function .doc_string or ""
59-
60- @property
61- def name (self ) -> str :
62- """Returns the name of this function."""
63- return self .ir_function .name
49+ return self .doc_string or ""
6450
6551 @property
6652 def assigned_names (self ) -> Sequence [str ]:
6753 """Returns the list of variables assigned to by this function."""
68- return [v .name for n in self .ir_function for v in n .outputs ]
69-
70- @property
71- def inputs (self ) -> Sequence [ir .Value ]:
72- return self .ir_function .inputs
54+ return [v .name for n in self for v in n .outputs ]
7355
7456 @property
7557 def attrs (self ) -> Sequence [ir .Attr ]:
7658 return [attr for attr in self .ordered_inputs_and_attrs if isinstance (attr , ir .Attr )]
7759
78- def __str__ (self ):
79- return str (self .ir_function )
80-
8160 def append_node (self , node : ir .Node ) -> None :
82- count = len (self . ir_function )
61+ count = len (self )
8362 node .name = f"n{ count } "
84- self .ir_function . append (node )
63+ self .append (node )
8564 domain = node .domain
8665 version = node .version
87- if domain not in self .ir_function . opset_imports :
88- self .ir_function . opset_imports [domain ] = version
66+ if domain not in self .opset_imports :
67+ self .opset_imports [domain ] = version
8968 else :
90- existing_version = self .ir_function . opset_imports [domain ]
69+ existing_version = self .opset_imports [domain ]
9170 if existing_version != version :
9271 warnings .warn (
9372 f"Version conflict: domain: { domain !r} , "
@@ -98,14 +77,14 @@ def append_node(self, node: ir.Node) -> None:
9877
9978 def append_input (self , var : ir .Value ) -> None :
10079 self .ordered_inputs_and_attrs .append (var )
101- self .ir_function . inputs .append (var )
80+ self .inputs .append (var )
10281
10382 def append_output (self , var : ir .Value ) -> None :
104- self .ir_function . outputs .append (var )
83+ self .outputs .append (var )
10584
10685 def add_attr_parameter (self , attr : ir .Attr ) -> None :
10786 self .ordered_inputs_and_attrs .append (attr )
108- self .ir_function . attributes .add (attr )
87+ self .attributes .add (attr )
10988
11089 def add_nested_function (self , fun : IRFunction ) -> None :
11190 self .nested_functions [fun .name ] = fun
@@ -114,7 +93,7 @@ def get_called_functions(self) -> dict[str, onnx.FunctionProto]:
11493 called_functions : dict [str , values .OnnxFunction ] = {}
11594
11695 def visit (function_ir : IRFunction ):
117- for node in ir .traversal .RecursiveGraphIterator (function_ir .ir_function . graph ):
96+ for node in ir .traversal .RecursiveGraphIterator (function_ir .graph ):
11897 callee = node .meta .get ("callee" , None )
11998 if isinstance (callee , values .OnnxFunction ):
12099 add (callee )
@@ -139,11 +118,11 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto:
139118 an instance of :class:`onnx.GraphProto`
140119 """
141120 del use_default_type # currently not used
142- return ir .to_proto (self .ir_function . graph )
121+ return ir .to_proto (self .graph )
143122
144123 def to_function_proto (self ) -> onnx .FunctionProto :
145124 """Converts this instance into a `onnx.FunctionProto`."""
146- return ir .to_proto (self . ir_function )
125+ return ir .to_proto (self )
147126
148127
149128# IRBuilder: abstracts out details of the IR in the python-to-IR converter
@@ -180,7 +159,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I
180159 return function
181160
182161 def add_docstring (self , fn : IRFunction , docstring : str ):
183- fn .ir_function . doc_string = docstring
162+ fn .doc_string = docstring
184163
185164 def add_stmt (
186165 self ,
0 commit comments