Skip to content

Commit 447d340

Browse files
committed
callbacks working
1 parent 611d5ae commit 447d340

File tree

6 files changed

+589
-152
lines changed

6 files changed

+589
-152
lines changed

mlir_utils/dialects/ext/func.py

Lines changed: 151 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from typing import Union, Optional
23

34
from mlir.dialects.func import FuncOp, ReturnOp, CallOp
45
from mlir.ir import (
@@ -8,7 +9,7 @@
89
TypeAttr,
910
FlatSymbolRefAttr,
1011
Type,
11-
Location,
12+
Value,
1213
)
1314

1415
from mlir_utils.util import (
@@ -19,16 +20,90 @@
1920
)
2021

2122

23+
def call(
24+
callee_or_results: Union[FuncOp, list[Type]],
25+
arguments_or_callee: Union[list[Value], FlatSymbolRefAttr, str],
26+
arguments: Optional[list] = None,
27+
*,
28+
call_op_ctor=CallOp.__base__,
29+
loc=None,
30+
ip=None,
31+
):
32+
"""Creates an call operation.
33+
34+
The constructor accepts three different forms:
35+
36+
1. A function op to be called followed by a list of arguments.
37+
2. A list of result types, followed by the name of the function to be
38+
called as string, following by a list of arguments.
39+
3. A list of result types, followed by the name of the function to be
40+
called as symbol reference attribute, followed by a list of arguments.
41+
42+
For example
43+
44+
f = func.FuncOp("foo", ...)
45+
func.CallOp(f, [args])
46+
func.CallOp([result_types], "foo", [args])
47+
48+
In all cases, the location and insertion point may be specified as keyword
49+
arguments if not provided by the surrounding context managers.
50+
"""
51+
if loc is None:
52+
loc = get_user_code_loc()
53+
if isinstance(callee_or_results, FuncOp.__base__):
54+
if not isinstance(arguments_or_callee, (list, tuple)):
55+
raise ValueError(
56+
"when constructing a call to a function, expected "
57+
+ "the second argument to be a list of call arguments, "
58+
+ f"got {type(arguments_or_callee)}"
59+
)
60+
if arguments is not None:
61+
raise ValueError(
62+
"unexpected third argument when constructing a call" + "to a function"
63+
)
64+
return call_op_ctor(
65+
callee_or_results.function_type.value.results,
66+
FlatSymbolRefAttr.get(callee_or_results.sym_name.value),
67+
arguments_or_callee,
68+
loc=loc,
69+
ip=ip,
70+
)
71+
72+
if isinstance(arguments_or_callee, list):
73+
raise ValueError(
74+
"when constructing a call to a function by name, "
75+
+ "expected the second argument to be a string or a "
76+
+ f"FlatSymbolRefAttr, got {type(arguments_or_callee)}"
77+
)
78+
79+
if isinstance(arguments_or_callee, FlatSymbolRefAttr):
80+
return call_op_ctor(
81+
callee_or_results, arguments_or_callee, arguments, loc=loc, ip=ip
82+
)
83+
elif isinstance(arguments_or_callee, str):
84+
return call_op_ctor(
85+
callee_or_results,
86+
FlatSymbolRefAttr.get(arguments_or_callee),
87+
arguments,
88+
loc=loc,
89+
ip=ip,
90+
)
91+
else:
92+
raise ValueError(f"unexpected type {callee_or_results=}")
93+
94+
2295
class FuncBase:
2396
def __init__(
2497
self,
2598
body_builder,
2699
func_op_ctor,
27100
return_op_ctor,
28101
call_op_ctor,
102+
return_types=None,
29103
sym_visibility=None,
30104
arg_attrs=None,
31105
res_attrs=None,
106+
func_attrs=None,
32107
loc=None,
33108
ip=None,
34109
):
@@ -40,6 +115,13 @@ def __init__(
40115
self.body_builder = body_builder
41116
self.func_name = self.body_builder.__name__
42117

118+
if return_types is None:
119+
return_types = []
120+
sig = inspect.signature(self.body_builder)
121+
self.input_types, self.return_types, self.arg_locs = self.prep_func_types(
122+
sig, return_types
123+
)
124+
43125
self.func_op_ctor = func_op_ctor
44126
self.return_op_ctor = return_op_ctor
45127
self.call_op_ctor = call_op_ctor
@@ -48,26 +130,63 @@ def __init__(
48130
)
49131
self.arg_attrs = arg_attrs
50132
self.res_attrs = res_attrs
133+
if func_attrs is None:
134+
func_attrs = {}
135+
self.func_attrs = func_attrs
51136
self.loc = loc
52137
self.ip = ip or InsertionPoint.current
53-
self.emitted = False
138+
self._func_op = None
139+
140+
if self._is_decl():
141+
assert len(self.input_types) == len(
142+
sig.parameters
143+
), f"func decl needs all input types annotated"
144+
self.sym_visibility = StringAttr.get("private")
145+
self.emit()
146+
147+
def _is_decl(self):
148+
# magic constant found from looking at the code for an empty fn
149+
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
54150

55151
def __str__(self):
56152
return str(f"{self.__class__} {self.__dict__}")
57153

154+
def prep_func_types(self, sig, return_types):
155+
assert not (
156+
not sig.return_annotation is inspect.Signature.empty
157+
and len(return_types) > 0
158+
), f"func can use return annotation or explicit return_types but not both"
159+
return_types = (
160+
sig.return_annotation
161+
if not sig.return_annotation is inspect.Signature.empty
162+
else return_types
163+
)
164+
if not isinstance(return_types, (tuple, list)):
165+
return_types = [return_types]
166+
return_types = list(return_types)
167+
assert all(
168+
isinstance(r, Type) for r in return_types
169+
), f"all return types must be mlir types {return_types=}"
170+
171+
input_types = [
172+
p.annotation
173+
for p in sig.parameters.values()
174+
if not p.annotation is inspect.Signature.empty
175+
]
176+
assert all(
177+
isinstance(r, Type) for r in input_types
178+
), f"all input types must be mlir types {input_types=}"
179+
return input_types, return_types, [get_user_code_loc()] * len(sig.parameters)
180+
58181
def body_builder_wrapper(self, *call_args):
59-
sig = inspect.signature(self.body_builder)
60-
implicit_return = sig.return_annotation is inspect._empty
61-
input_types = [p.annotation for p in sig.parameters.values()]
62-
if not (
63-
len(input_types) == len(sig.parameters)
64-
and all(isinstance(t, Type) for t in input_types)
65-
):
182+
if len(call_args) == 0:
183+
input_types = self.input_types
184+
else:
66185
input_types = [a.type for a in call_args]
67186
function_type = TypeAttr.get(
68187
FunctionType.get(
69188
inputs=input_types,
70-
results=[] if implicit_return else sig.return_annotation,
189+
results=self.return_types,
71190
)
72191
)
73192
func_op = self.func_op_ctor(
@@ -79,8 +198,10 @@ def body_builder_wrapper(self, *call_args):
79198
loc=self.loc,
80199
ip=self.ip,
81200
)
82-
arg_locs = [get_user_code_loc()] * len(sig.parameters)
83-
func_op.regions[0].blocks.append(*input_types, arg_locs=arg_locs)
201+
if self._is_decl():
202+
return self.return_types, input_types, func_op
203+
204+
func_op.regions[0].blocks.append(*input_types, arg_locs=self.arg_locs)
84205
with InsertionPoint(func_op.regions[0].blocks[0]):
85206
results = get_result_or_results(
86207
self.body_builder(
@@ -94,31 +215,23 @@ def body_builder_wrapper(self, *call_args):
94215
results = [results]
95216
else:
96217
results = []
218+
97219
self.return_op_ctor(results)
220+
return_types = [r.type for r in results]
221+
return return_types, input_types, func_op
98222

99-
return results, input_types, func_op
100-
101-
def emit(self):
102-
self.results, input_types, func_op = self.body_builder_wrapper()
103-
return_types = [v.type for v in self.results]
104-
function_type = FunctionType.get(inputs=input_types, results=return_types)
105-
func_op.attributes["function_type"] = TypeAttr.get(function_type)
106-
self.emitted = True
107-
# this is the func op itself (funcs never have a resulting ssa value)
108-
return maybe_cast(get_result_or_results(func_op))
109-
110-
def __call__(self, *call_args, loc: Location = None):
111-
if loc is None:
112-
loc = get_user_code_loc()
113-
if not self.emitted:
114-
self.emit()
115-
call_op = self.call_op_ctor(
116-
[r.type for r in self.results],
117-
FlatSymbolRefAttr.get(self.func_name),
118-
call_args,
119-
loc=loc,
120-
)
121-
return maybe_cast(get_result_or_results(call_op))
223+
def emit(self) -> FuncOp:
224+
if self._func_op is None:
225+
return_types, input_types, func_op = self.body_builder_wrapper()
226+
function_type = FunctionType.get(inputs=input_types, results=return_types)
227+
func_op.attributes["function_type"] = TypeAttr.get(function_type)
228+
for k, v in self.func_attrs.items():
229+
func_op.attributes[k] = v
230+
self._func_op = func_op
231+
return self._func_op
232+
233+
def __call__(self, *call_args):
234+
return call(self.emit(), call_args)
122235

123236

124237
@make_maybe_no_args_decorator
@@ -128,9 +241,10 @@ def func(
128241
sym_visibility=None,
129242
arg_attrs=None,
130243
res_attrs=None,
244+
func_attrs=None,
131245
loc=None,
132246
ip=None,
133-
):
247+
) -> FuncBase:
134248
if loc is None:
135249
loc = get_user_code_loc()
136250
return FuncBase(
@@ -141,48 +255,7 @@ def func(
141255
sym_visibility=sym_visibility,
142256
arg_attrs=arg_attrs,
143257
res_attrs=res_attrs,
258+
func_attrs=func_attrs,
144259
loc=loc,
145260
ip=ip,
146261
)
147-
148-
149-
def call(symbol_name, call_args, return_types, *, loc=None, ip=None):
150-
if loc is None:
151-
loc = get_user_code_loc()
152-
return maybe_cast(
153-
get_result_or_results(
154-
CallOp.__base__(
155-
return_types,
156-
FlatSymbolRefAttr.get(symbol_name),
157-
call_args,
158-
loc=loc,
159-
ip=ip,
160-
)
161-
)
162-
)
163-
164-
165-
def declare(
166-
symbol_name,
167-
input_types: list,
168-
result_types=None,
169-
func_op_ctor=FuncOp,
170-
):
171-
if result_types is None:
172-
result_types = []
173-
assert all(
174-
isinstance(a, Type) for a in input_types
175-
), f"wrong func args {input_types}"
176-
assert all(
177-
isinstance(a, Type) for a in result_types
178-
), f"wrong func results {result_types}"
179-
180-
function_type = FunctionType.get(inputs=input_types, results=result_types)
181-
sym_name = func_op_ctor(
182-
name=symbol_name, type=function_type, visibility="private"
183-
).sym_name
184-
185-
def callable(*call_args):
186-
return call(sym_name.value, call_args, result_types)
187-
188-
return callable

0 commit comments

Comments
 (0)