11import inspect
2+ from typing import Union , Optional
23
34from mlir .dialects .func import FuncOp , ReturnOp , CallOp
45from mlir .ir import (
89 TypeAttr ,
910 FlatSymbolRefAttr ,
1011 Type ,
11- Location ,
12+ Value ,
1213)
1314
1415from mlir_utils .util import (
1920)
2021
2122
23+ def call (
24+ callee_or_results : Union [FuncOp , list [Type ]],
25+ arguments_or_callee : Union [list [Value ], FlatSymbolRefAttr , str ],
26+ arguments : Optional [list ] = None ,
27+ * ,
28+ call_op_ctor = CallOp .__base__ ,
29+ loc = None ,
30+ ip = None ,
31+ ):
32+ """Creates an call operation.
33+
34+ The constructor accepts three different forms:
35+
36+ 1. A function op to be called followed by a list of arguments.
37+ 2. A list of result types, followed by the name of the function to be
38+ called as string, following by a list of arguments.
39+ 3. A list of result types, followed by the name of the function to be
40+ called as symbol reference attribute, followed by a list of arguments.
41+
42+ For example
43+
44+ f = func.FuncOp("foo", ...)
45+ func.CallOp(f, [args])
46+ func.CallOp([result_types], "foo", [args])
47+
48+ In all cases, the location and insertion point may be specified as keyword
49+ arguments if not provided by the surrounding context managers.
50+ """
51+ if loc is None :
52+ loc = get_user_code_loc ()
53+ if isinstance (callee_or_results , FuncOp .__base__ ):
54+ if not isinstance (arguments_or_callee , (list , tuple )):
55+ raise ValueError (
56+ "when constructing a call to a function, expected "
57+ + "the second argument to be a list of call arguments, "
58+ + f"got { type (arguments_or_callee )} "
59+ )
60+ if arguments is not None :
61+ raise ValueError (
62+ "unexpected third argument when constructing a call" + "to a function"
63+ )
64+ return call_op_ctor (
65+ callee_or_results .function_type .value .results ,
66+ FlatSymbolRefAttr .get (callee_or_results .sym_name .value ),
67+ arguments_or_callee ,
68+ loc = loc ,
69+ ip = ip ,
70+ )
71+
72+ if isinstance (arguments_or_callee , list ):
73+ raise ValueError (
74+ "when constructing a call to a function by name, "
75+ + "expected the second argument to be a string or a "
76+ + f"FlatSymbolRefAttr, got { type (arguments_or_callee )} "
77+ )
78+
79+ if isinstance (arguments_or_callee , FlatSymbolRefAttr ):
80+ return call_op_ctor (
81+ callee_or_results , arguments_or_callee , arguments , loc = loc , ip = ip
82+ )
83+ elif isinstance (arguments_or_callee , str ):
84+ return call_op_ctor (
85+ callee_or_results ,
86+ FlatSymbolRefAttr .get (arguments_or_callee ),
87+ arguments ,
88+ loc = loc ,
89+ ip = ip ,
90+ )
91+ else :
92+ raise ValueError (f"unexpected type { callee_or_results = } " )
93+
94+
2295class FuncBase :
2396 def __init__ (
2497 self ,
2598 body_builder ,
2699 func_op_ctor ,
27100 return_op_ctor ,
28101 call_op_ctor ,
102+ return_types = None ,
29103 sym_visibility = None ,
30104 arg_attrs = None ,
31105 res_attrs = None ,
106+ func_attrs = None ,
32107 loc = None ,
33108 ip = None ,
34109 ):
@@ -40,6 +115,13 @@ def __init__(
40115 self .body_builder = body_builder
41116 self .func_name = self .body_builder .__name__
42117
118+ if return_types is None :
119+ return_types = []
120+ sig = inspect .signature (self .body_builder )
121+ self .input_types , self .return_types , self .arg_locs = self .prep_func_types (
122+ sig , return_types
123+ )
124+
43125 self .func_op_ctor = func_op_ctor
44126 self .return_op_ctor = return_op_ctor
45127 self .call_op_ctor = call_op_ctor
@@ -48,26 +130,63 @@ def __init__(
48130 )
49131 self .arg_attrs = arg_attrs
50132 self .res_attrs = res_attrs
133+ if func_attrs is None :
134+ func_attrs = {}
135+ self .func_attrs = func_attrs
51136 self .loc = loc
52137 self .ip = ip or InsertionPoint .current
53- self .emitted = False
138+ self ._func_op = None
139+
140+ if self ._is_decl ():
141+ assert len (self .input_types ) == len (
142+ sig .parameters
143+ ), f"func decl needs all input types annotated"
144+ self .sym_visibility = StringAttr .get ("private" )
145+ self .emit ()
146+
147+ def _is_decl (self ):
148+ # magic constant found from looking at the code for an empty fn
149+ return self .body_builder .__code__ .co_code == b"\x97 \x00 d\x00 S\x00 "
54150
55151 def __str__ (self ):
56152 return str (f"{ self .__class__ } { self .__dict__ } " )
57153
154+ def prep_func_types (self , sig , return_types ):
155+ assert not (
156+ not sig .return_annotation is inspect .Signature .empty
157+ and len (return_types ) > 0
158+ ), f"func can use return annotation or explicit return_types but not both"
159+ return_types = (
160+ sig .return_annotation
161+ if not sig .return_annotation is inspect .Signature .empty
162+ else return_types
163+ )
164+ if not isinstance (return_types , (tuple , list )):
165+ return_types = [return_types ]
166+ return_types = list (return_types )
167+ assert all (
168+ isinstance (r , Type ) for r in return_types
169+ ), f"all return types must be mlir types { return_types = } "
170+
171+ input_types = [
172+ p .annotation
173+ for p in sig .parameters .values ()
174+ if not p .annotation is inspect .Signature .empty
175+ ]
176+ assert all (
177+ isinstance (r , Type ) for r in input_types
178+ ), f"all input types must be mlir types { input_types = } "
179+ return input_types , return_types , [get_user_code_loc ()] * len (sig .parameters )
180+
58181 def body_builder_wrapper (self , * call_args ):
59- sig = inspect .signature (self .body_builder )
60- implicit_return = sig .return_annotation is inspect ._empty
61- input_types = [p .annotation for p in sig .parameters .values ()]
62- if not (
63- len (input_types ) == len (sig .parameters )
64- and all (isinstance (t , Type ) for t in input_types )
65- ):
182+ if len (call_args ) == 0 :
183+ input_types = self .input_types
184+ else :
66185 input_types = [a .type for a in call_args ]
67186 function_type = TypeAttr .get (
68187 FunctionType .get (
69188 inputs = input_types ,
70- results = [] if implicit_return else sig . return_annotation ,
189+ results = self . return_types ,
71190 )
72191 )
73192 func_op = self .func_op_ctor (
@@ -79,8 +198,10 @@ def body_builder_wrapper(self, *call_args):
79198 loc = self .loc ,
80199 ip = self .ip ,
81200 )
82- arg_locs = [get_user_code_loc ()] * len (sig .parameters )
83- func_op .regions [0 ].blocks .append (* input_types , arg_locs = arg_locs )
201+ if self ._is_decl ():
202+ return self .return_types , input_types , func_op
203+
204+ func_op .regions [0 ].blocks .append (* input_types , arg_locs = self .arg_locs )
84205 with InsertionPoint (func_op .regions [0 ].blocks [0 ]):
85206 results = get_result_or_results (
86207 self .body_builder (
@@ -94,31 +215,23 @@ def body_builder_wrapper(self, *call_args):
94215 results = [results ]
95216 else :
96217 results = []
218+
97219 self .return_op_ctor (results )
220+ return_types = [r .type for r in results ]
221+ return return_types , input_types , func_op
98222
99- return results , input_types , func_op
100-
101- def emit (self ):
102- self .results , input_types , func_op = self .body_builder_wrapper ()
103- return_types = [v .type for v in self .results ]
104- function_type = FunctionType .get (inputs = input_types , results = return_types )
105- func_op .attributes ["function_type" ] = TypeAttr .get (function_type )
106- self .emitted = True
107- # this is the func op itself (funcs never have a resulting ssa value)
108- return maybe_cast (get_result_or_results (func_op ))
109-
110- def __call__ (self , * call_args , loc : Location = None ):
111- if loc is None :
112- loc = get_user_code_loc ()
113- if not self .emitted :
114- self .emit ()
115- call_op = self .call_op_ctor (
116- [r .type for r in self .results ],
117- FlatSymbolRefAttr .get (self .func_name ),
118- call_args ,
119- loc = loc ,
120- )
121- return maybe_cast (get_result_or_results (call_op ))
223+ def emit (self ) -> FuncOp :
224+ if self ._func_op is None :
225+ return_types , input_types , func_op = self .body_builder_wrapper ()
226+ function_type = FunctionType .get (inputs = input_types , results = return_types )
227+ func_op .attributes ["function_type" ] = TypeAttr .get (function_type )
228+ for k , v in self .func_attrs .items ():
229+ func_op .attributes [k ] = v
230+ self ._func_op = func_op
231+ return self ._func_op
232+
233+ def __call__ (self , * call_args ):
234+ return call (self .emit (), call_args )
122235
123236
124237@make_maybe_no_args_decorator
@@ -128,9 +241,10 @@ def func(
128241 sym_visibility = None ,
129242 arg_attrs = None ,
130243 res_attrs = None ,
244+ func_attrs = None ,
131245 loc = None ,
132246 ip = None ,
133- ):
247+ ) -> FuncBase :
134248 if loc is None :
135249 loc = get_user_code_loc ()
136250 return FuncBase (
@@ -141,48 +255,7 @@ def func(
141255 sym_visibility = sym_visibility ,
142256 arg_attrs = arg_attrs ,
143257 res_attrs = res_attrs ,
258+ func_attrs = func_attrs ,
144259 loc = loc ,
145260 ip = ip ,
146261 )
147-
148-
149- def call (symbol_name , call_args , return_types , * , loc = None , ip = None ):
150- if loc is None :
151- loc = get_user_code_loc ()
152- return maybe_cast (
153- get_result_or_results (
154- CallOp .__base__ (
155- return_types ,
156- FlatSymbolRefAttr .get (symbol_name ),
157- call_args ,
158- loc = loc ,
159- ip = ip ,
160- )
161- )
162- )
163-
164-
165- def declare (
166- symbol_name ,
167- input_types : list ,
168- result_types = None ,
169- func_op_ctor = FuncOp ,
170- ):
171- if result_types is None :
172- result_types = []
173- assert all (
174- isinstance (a , Type ) for a in input_types
175- ), f"wrong func args { input_types } "
176- assert all (
177- isinstance (a , Type ) for a in result_types
178- ), f"wrong func results { result_types } "
179-
180- function_type = FunctionType .get (inputs = input_types , results = result_types )
181- sym_name = func_op_ctor (
182- name = symbol_name , type = function_type , visibility = "private"
183- ).sym_name
184-
185- def callable (* call_args ):
186- return call (sym_name .value , call_args , result_types )
187-
188- return callable
0 commit comments