Skip to content

Commit 16d2a67

Browse files
authored
support user-specified func_type (#125)
1 parent f08db06 commit 16d2a67

File tree

3 files changed

+87
-29
lines changed

3 files changed

+87
-29
lines changed

mlir/extras/dialects/ext/arith.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context)
244244
"ule": CmpIPredicate.ule,
245245
"ugt": CmpIPredicate.ugt,
246246
"uge": CmpIPredicate.uge,
247+
0: CmpIPredicate.eq,
248+
1: CmpIPredicate.ne,
249+
2: CmpIPredicate.slt,
250+
3: CmpIPredicate.sle,
251+
4: CmpIPredicate.sgt,
252+
5: CmpIPredicate.sge,
253+
6: CmpIPredicate.ult,
254+
7: CmpIPredicate.ule,
255+
8: CmpIPredicate.ugt,
256+
9: CmpIPredicate.uge,
247257
}
248258
if isinstance(predicate, Attribute):
249259
return predicate
@@ -410,6 +420,9 @@ def literal_value(self):
410420
def coerce(self, other) -> Tuple["ArithValue", "ArithValue"]:
411421
pass
412422

423+
def __hash__(self):
424+
return Value(self).__hash__()
425+
413426
def fold(self) -> bool:
414427
return self.is_constant() and self._fold
415428

mlir/extras/dialects/ext/func.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,26 @@ def __init__(
144144
*,
145145
return_types=None,
146146
sym_visibility=None,
147+
sym_name=None,
147148
arg_attrs=None,
148149
res_attrs=None,
149150
func_attrs=None,
151+
function_type=None,
150152
generics: List[Union[TypeVar, ReifiedTypeParams]] = None,
151153
qualname=None,
152154
loc=None,
153155
ip=None,
154156
):
155157
assert inspect.isfunction(body_builder), body_builder
156158
assert inspect.isclass(func_op_ctor), func_op_ctor
157-
assert inspect.isclass(return_op_ctor), return_op_ctor
159+
if return_op_ctor is not None:
160+
assert inspect.isclass(return_op_ctor), return_op_ctor
158161
assert inspect.isclass(call_op_ctor), call_op_ctor
159162

160163
self.body_builder = body_builder
161-
self.func_name = self.body_builder.__name__
164+
if sym_name is None:
165+
sym_name = self.body_builder.__name__
166+
self.func_name = sym_name
162167
self.func_op_ctor = func_op_ctor
163168
self.return_op_ctor = return_op_ctor
164169
self.call_op_ctor = call_op_ctor
@@ -175,6 +180,7 @@ def __init__(
175180
self.func_attrs = func_attrs
176181
if self.func_attrs is None:
177182
self.func_attrs = {}
183+
self.function_type = function_type
178184

179185
if return_types is None:
180186
return_types = []
@@ -208,32 +214,37 @@ def __str__(self):
208214

209215
def emit(self, *call_args, decl=False, force=False) -> FuncOp:
210216
if self._func_op is None or decl or force:
211-
if len(call_args) == 0:
212-
input_types = self.input_types[:]
213-
locals = {"T": T}
214-
if self.generics is not None:
215-
for t in self.generics:
216-
if not isinstance(t, ReifiedTypeParams):
217-
raise RuntimeError(f"{t=} must reified")
218-
locals[t.name] = t.val
219-
for i, v in enumerate(input_types):
220-
if isinstance(v, TypeVar):
221-
v = v.__name__
222-
if isinstance(v, str):
223-
input_types[i] = Type(
224-
eval(v, self.body_builder.__globals__, locals)
225-
)
226-
elif isalambda(v):
227-
input_types[i] = v()
228-
else:
229-
input_types = [a.type for a in call_args]
217+
if self.function_type is None:
218+
if len(call_args) == 0:
219+
input_types = self.input_types[:]
220+
locals = {"T": T}
221+
if self.generics is not None:
222+
for t in self.generics:
223+
if not isinstance(t, ReifiedTypeParams):
224+
raise RuntimeError(f"{t=} must reified")
225+
locals[t.name] = t.val
226+
for i, v in enumerate(input_types):
227+
if isinstance(v, TypeVar):
228+
v = v.__name__
229+
if isinstance(v, str):
230+
input_types[i] = Type(
231+
eval(v, self.body_builder.__globals__, locals)
232+
)
233+
elif isalambda(v):
234+
input_types[i] = v()
235+
else:
236+
input_types = [a.type for a in call_args]
230237

231-
function_type = TypeAttr.get(
232-
FunctionType.get(
233-
inputs=input_types,
234-
results=self.return_types,
238+
function_type = TypeAttr.get(
239+
FunctionType.get(
240+
inputs=input_types,
241+
results=self.return_types,
242+
)
235243
)
236-
)
244+
else:
245+
input_types = self.function_type.inputs
246+
function_type = TypeAttr.get(self.function_type)
247+
237248
self._func_op = self.func_op_ctor(
238249
self.func_name,
239250
function_type,
@@ -264,10 +275,15 @@ def grab_results(*args):
264275
return_types.append(results.type)
265276
return results
266277

267-
builder_wrapper(grab_results)
278+
if self.function_type is None:
279+
builder_wrapper(grab_results)
280+
function_type = FunctionType.get(
281+
inputs=input_types, results=return_types
282+
)
283+
self._func_op.attributes["function_type"] = TypeAttr.get(function_type)
284+
else:
285+
builder_wrapper(self.body_builder)
268286

269-
function_type = FunctionType.get(inputs=input_types, results=return_types)
270-
self._func_op.attributes["function_type"] = TypeAttr.get(function_type)
271287
return self._func_op
272288

273289
def __call__(self, *call_args):
@@ -345,9 +361,11 @@ def func(
345361
f,
346362
*,
347363
sym_visibility=None,
364+
sym_name=None,
348365
arg_attrs=None,
349366
res_attrs=None,
350367
func_attrs=None,
368+
function_type=None,
351369
emit=False,
352370
generics=None,
353371
loc=None,
@@ -363,9 +381,11 @@ def func(
363381
return_op_ctor=ReturnOp,
364382
call_op_ctor=CallOp.__base__,
365383
sym_visibility=sym_visibility,
384+
sym_name=sym_name,
366385
arg_attrs=arg_attrs,
367386
res_attrs=res_attrs,
368387
func_attrs=func_attrs,
388+
function_type=function_type,
369389
generics=generics,
370390
loc=loc,
371391
ip=ip,

tests/test_func.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mlir.extras.dialects.ext.arith import constant
1414
from mlir.extras.dialects.ext.func import func
1515
from mlir.extras.dialects.ext import linalg, arith, scf, memref
16+
from mlir.ir import FunctionType
1617

1718
# noinspection PyUnresolvedReferences
1819
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
@@ -320,3 +321,27 @@ def demo_fun1():
320321
"""
321322
)
322323
filecheck(correct, tls.ctx.module)
324+
325+
326+
def test_explicit_function_type(ctx: MLIRContext):
327+
input_types = [T.i32(), T.i32()]
328+
result_types = [T.i32()]
329+
func_type = FunctionType.get(input_types, result_types)
330+
331+
@func(function_type=func_type)
332+
def demo_fun1(a, b):
333+
one = constant(1)
334+
return one
335+
336+
demo_fun1.emit()
337+
correct = dedent(
338+
"""\
339+
module {
340+
func.func @demo_fun1(%arg0: i32, %arg1: i32) -> i32 {
341+
%c1_i32 = arith.constant 1 : i32
342+
return %c1_i32 : i32
343+
}
344+
}
345+
"""
346+
)
347+
filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)