Skip to content

Commit 323684b

Browse files
committed
add region op decorator
1 parent 0b02513 commit 323684b

File tree

8 files changed

+305
-36
lines changed

8 files changed

+305
-36
lines changed

mlir_utils/dialects/.gitignore

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
arith.py
2-
async_dialect.py
3-
bufferization.py
4-
builtin.py
5-
cf.py
6-
complex.py
7-
func.py
8-
gpu.py
9-
linalg.py
10-
math.py
11-
memref.py
12-
ml_program.py
13-
pdl.py
14-
quant.py
15-
scf.py
16-
shape.py
17-
sparse_tensor.py
18-
tensor.py
19-
torch.py
20-
tosa.py
21-
transform.py
22-
vector.py
1+
/arith.py
2+
/async_dialect.py
3+
/bufferization.py
4+
/builtin.py
5+
/cf.py
6+
/complex.py
7+
/func.py
8+
/gpu.py
9+
/linalg.py
10+
/math.py
11+
/memref.py
12+
/ml_program.py
13+
/pdl.py
14+
/quant.py
15+
/scf.py
16+
/shape.py
17+
/sparse_tensor.py
18+
/tensor.py
19+
/torch.py
20+
/tosa.py
21+
/transform.py
22+
/vector.py

mlir_utils/dialects/ext/__init__.py

Whitespace-only changes.

mlir_utils/dialects/ext/func.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import inspect
2+
from functools import wraps
3+
4+
from mlir.dialects.func import FuncOp, ReturnOp
5+
from mlir.ir import InsertionPoint, FunctionType, StringAttr, TypeAttr
6+
7+
from mlir_utils.dialects.util import (
8+
get_result_or_results,
9+
make_maybe_no_args_decorator,
10+
)
11+
12+
13+
@make_maybe_no_args_decorator
14+
def func(sym_visibility=None, arg_attrs=None, res_attrs=None, loc=None, ip=None):
15+
ip = ip or InsertionPoint.current
16+
17+
def builder_wrapper(body_builder):
18+
@wraps(body_builder)
19+
def wrapper(*args):
20+
sig = inspect.signature(body_builder)
21+
implicit_return = sig.return_annotation is inspect._empty
22+
function_type = TypeAttr.get(
23+
FunctionType.get(
24+
inputs=[a.type for a in args],
25+
results=[] if implicit_return else sig.return_annotation,
26+
)
27+
)
28+
# FuncOp is extended but we do really want the base
29+
op = FuncOp.__base__(
30+
body_builder.__name__,
31+
function_type,
32+
sym_visibility=StringAttr.get(str(sym_visibility))
33+
if sym_visibility is not None
34+
else None,
35+
arg_attrs=arg_attrs,
36+
res_attrs=res_attrs,
37+
loc=loc,
38+
ip=ip,
39+
)
40+
op.regions[0].blocks.append(*[a.type for a in args])
41+
with InsertionPoint(op.regions[0].blocks[0]):
42+
r = get_result_or_results(
43+
body_builder(*op.regions[0].blocks[0].arguments)
44+
)
45+
if r is not None:
46+
if isinstance(r, (tuple, list)):
47+
ReturnOp(list(r))
48+
else:
49+
ReturnOp([r])
50+
else:
51+
ReturnOp([])
52+
return r
53+
54+
# wrapper.op = op
55+
return wrapper
56+
57+
return builder_wrapper

mlir_utils/dialects/generate_trampolines.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def ast_call(name, args=None, keywords=None):
2121
)
2222

2323

24-
def make_fun(op_name, op_class):
24+
# TODO(max): ops that have symboltables need to be classes but that requires some upstream support for statically
25+
# identifying such ops
26+
def generate_free_fun(op_class):
2527
_mod = ast.parse(dedent(inspect.getsource(op_class.__init__)))
2628
init_fn = next(n for n in _mod.body if isinstance(n, ast.FunctionDef))
2729
args = init_fn.args
@@ -39,17 +41,33 @@ def make_fun(op_name, op_class):
3941

4042
for a in args.args + args.kwonlyargs:
4143
a.annotation = None
42-
ret = ast.parse(
43-
f"return get_result_or_results({ast.unparse(ast_call(op_name, args.args, keywords))})"
44-
).body[0]
4544
fun_name = op_class.OPERATION_NAME.split(".")[-1]
4645
if keyword.iskeyword(fun_name):
4746
fun_name = fun_name + "_"
47+
op_class_name = op_class.__name__
48+
body = []
49+
if len(args.args) == 1 and args.args[0].arg == "results_":
50+
args.defaults.append(ast.Constant(None))
51+
body += [ast.parse("results_ = results_ or []").body[0]]
52+
if (
53+
hasattr(op_class, "_ODS_REGIONS")
54+
and op_class._ODS_REGIONS[0] == 1
55+
and not op_class.OPERATION_NAME.startswith("linalg")
56+
):
57+
decorator_list = [ast.Name(id="region_op", ctx=ast.Load())]
58+
body += [ast.Return([ast_call(op_class_name, args.args, keywords)])]
59+
else:
60+
decorator_list = []
61+
body += [
62+
ast.parse(
63+
f"return get_result_or_results({ast.unparse(ast_call(op_class_name, args.args, keywords))})"
64+
).body[0]
65+
]
4866
n = ast.FunctionDef(
4967
name=fun_name,
5068
args=copy.deepcopy(args),
51-
body=[ret],
52-
decorator_list=[],
69+
body=body,
70+
decorator_list=decorator_list,
5371
)
5472
ast.fix_missing_locations(n)
5573
return n
@@ -58,6 +76,7 @@ def make_fun(op_name, op_class):
5876
def generate_trampoline(input_module, output_file_path, skips=None):
5977
import mlir_utils
6078
from mlir_utils.dialects.util import get_result_or_results
79+
import mlir.dialects._ods_common
6180

6281
if skips is None:
6382
skips = set()
@@ -69,6 +88,9 @@ def generate_trampoline(input_module, output_file_path, skips=None):
6988
and hasattr(obj, "OPERATION_NAME")
7089
and obj.__name__ not in skips
7190
):
91+
if obj.__module__ == mlir.dialects._ods_common.__name__:
92+
# these are extension classes and we should wrap the generated class instead
93+
obj = obj.__base__
7294
if not inspect.isfunction(obj.__init__):
7395
# some builders don't have any __init__ but inherit from opview
7496
continue
@@ -77,11 +99,14 @@ def generate_trampoline(input_module, output_file_path, skips=None):
7799
if not len(init_funs):
78100
return
79101

80-
functions = [make_fun(n, s) for n, s in init_funs.items()]
102+
functions = [
103+
generate_free_fun(op_class)
104+
for op_class in sorted(init_funs.values(), key=lambda o: o.__name__)
105+
]
81106

82107
ods_imports = ast.ImportFrom(
83108
module=mlir_utils.dialects.util.__name__,
84-
names=[ast.alias(get_result_or_results.__name__)],
109+
names=[ast.alias(get_result_or_results.__name__), ast.alias("region_op")],
85110
level=0,
86111
)
87112
op_imports = ast.ImportFrom(

mlir_utils/dialects/util.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,93 @@
1+
from functools import wraps
2+
3+
import numpy as np
4+
from mlir.dialects import arith
15
from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
6+
from mlir.ir import (
7+
InsertionPoint,
8+
IntegerType,
9+
F64Type,
10+
RankedTensorType,
11+
IndexType,
12+
F16Type,
13+
F32Type,
14+
)
215

316

417
def get_result_or_results(op):
18+
if op is None:
19+
return
520
return (
621
get_op_results_or_values(op)
722
if len(op.operation.results) > 1
823
else get_op_result_or_value(op)
924
if len(op.operation.results) > 0
1025
else None
1126
)
27+
28+
29+
def make_maybe_no_args_decorator(decorator):
30+
@wraps(decorator)
31+
def maybe_no_args(*args, **kwargs):
32+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
33+
return decorator()(args[0])
34+
else:
35+
return decorator(*args, **kwargs)
36+
37+
return maybe_no_args
38+
39+
40+
# builds the decorator
41+
def region_op(op_constructor):
42+
# the decorator itself
43+
def op_decorator(*args, **kwargs):
44+
op = op_constructor(*args, **kwargs)
45+
46+
def builder_wrapper(body_builder):
47+
@wraps(body_builder)
48+
def wrapper(*args):
49+
# add a block with block args having types ...
50+
op.regions[0].blocks.append(*[a.type for a in args])
51+
with InsertionPoint(op.regions[0].blocks[0]):
52+
return get_result_or_results(
53+
body_builder(*op.regions[0].blocks[0].arguments)
54+
)
55+
56+
wrapper.op = op
57+
return wrapper
58+
59+
return builder_wrapper
60+
61+
return make_maybe_no_args_decorator(op_decorator)
62+
63+
64+
def infer_mlir_type(
65+
py_val: int | float | bool | np.ndarray,
66+
) -> IntegerType | F64Type | RankedTensorType:
67+
if isinstance(py_val, bool):
68+
return IntegerType.get_signless(1)
69+
elif isinstance(py_val, int):
70+
return IntegerType.get_signless(64)
71+
elif isinstance(py_val, float):
72+
return F64Type.get()
73+
elif isinstance(py_val, np.ndarray):
74+
dtype_ = {
75+
np.int8: IntegerType.get_signless(8),
76+
np.int16: IntegerType.get_signless(16),
77+
np.int32: IntegerType.get_signless(32),
78+
np.int64: IntegerType.get_signless(64),
79+
np.uintp: IndexType.get(),
80+
np.longlong: IndexType.get(),
81+
np.float16: F16Type.get(),
82+
np.float32: F32Type.get(),
83+
np.float64: F64Type.get(),
84+
}[py_val.dtype.type]
85+
return RankedTensorType.get(py_val.shape, dtype_)
86+
else:
87+
raise NotImplementedError(
88+
f"Unsupported Python value {py_val=} with type {type(py_val)}"
89+
)
90+
91+
92+
def constant(py_val: int | float | bool | np.ndarray):
93+
return arith.ConstantOp(infer_mlir_type(py_val), py_val)

mlir_utils/testing/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def filecheck(correct: str, module):
2525
assert Path(filecheck_path).exists() is not None, "couldn't find FileCheck"
2626

2727
correct = dedent(correct)
28+
correct_with_checks = main(correct).replace("CHECK:", "CHECK-NEXT:")
2829
op = dedent(str(module).strip())
2930
with tempfile.NamedTemporaryFile() as tmp:
30-
correct_with_checks = main(correct)
3131
tmp.write(correct_with_checks.encode())
3232
tmp.flush()
3333
p = Popen([filecheck_path, tmp.name], stdout=PIPE, stdin=PIPE, stderr=PIPE)
3434
out, err = map(lambda o: o.decode(), p.communicate(input=op.encode()))
35-
if len(err):
35+
if p.returncode:
3636
raise ValueError(err)
3737

3838

tests/test_regions.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from textwrap import dedent
2+
3+
import pytest
4+
5+
from mlir_utils.dialects.ext.func import func
6+
from mlir_utils.dialects.memref import alloca_scope, return_
7+
from mlir_utils.dialects.scf import execute_region, yield_
8+
from mlir_utils.dialects.util import constant
9+
10+
# noinspection PyUnresolvedReferences
11+
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
12+
13+
# needed since the fix isn't defined here nor conftest.py
14+
pytest.mark.usefixtures("ctx")
15+
16+
17+
def test_simple_region_op(ctx: MLIRContext):
18+
@execute_region([])
19+
def demo_region():
20+
one = constant(1)
21+
yield_()
22+
23+
demo_region()
24+
25+
ctx.module.operation.verify()
26+
filecheck(
27+
dedent(
28+
"""\
29+
module {
30+
scf.execute_region {
31+
%c1_i64 = arith.constant 1 : i64
32+
scf.yield
33+
}
34+
}
35+
"""
36+
),
37+
ctx.module,
38+
)
39+
40+
41+
def test_no_args_decorator(ctx: MLIRContext):
42+
@alloca_scope([])
43+
def demo_scope1():
44+
one = constant(1)
45+
return_()
46+
47+
@alloca_scope
48+
def demo_scope2():
49+
one = constant(2)
50+
return_()
51+
52+
demo_scope1()
53+
demo_scope2()
54+
55+
ctx.module.operation.verify()
56+
filecheck(
57+
dedent(
58+
"""\
59+
module {
60+
memref.alloca_scope {
61+
%c1_i64 = arith.constant 1 : i64
62+
}
63+
memref.alloca_scope {
64+
%c2_i64 = arith.constant 2 : i64
65+
}
66+
}
67+
"""
68+
),
69+
ctx.module,
70+
)
71+
72+
73+
def test_func(ctx: MLIRContext):
74+
@func
75+
def demo_fun1():
76+
one = constant(1)
77+
return
78+
79+
demo_fun1()
80+
ctx.module.operation.verify()
81+
filecheck(
82+
dedent(
83+
"""\
84+
module {
85+
func.func @demo_fun1() {
86+
%c1_i64 = arith.constant 1 : i64
87+
return
88+
}
89+
}
90+
"""
91+
),
92+
ctx.module,
93+
)

0 commit comments

Comments
 (0)