Skip to content

Commit 6d9143e

Browse files
committed
port over operator overloading from iree-sandbox
1 parent 9f09562 commit 6d9143e

File tree

13 files changed

+699
-96
lines changed

13 files changed

+699
-96
lines changed

.github/workflows/wheels.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,16 @@ jobs:
7070
allowUpdates: true
7171
replacesArtifacts: true
7272
makeLatest: true
73+
74+
- name: Release current commit
75+
uses: ncipollo/[email protected]
76+
with:
77+
artifacts: "dist/*.whl"
78+
bodyFile: body.md
79+
token: "${{ secrets.GITHUB_TOKEN }}"
80+
tag: "latest"
81+
name: "latest"
82+
removeArtifacts: false
83+
allowUpdates: true
84+
replacesArtifacts: true
85+
makeLatest: true

README.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,23 @@ The missing pieces (as far as boilerplate reduction goes) of the upstream MLIR p
44

55
## Install
66

7+
### TL;DR
8+
9+
```shell
10+
$ pip install .[mlir] -f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
11+
$ configure-mlir-utils mlir
12+
13+
or for maximum convenience
14+
15+
```shell
16+
$ pip install mlir-python-utils[mlir] \
17+
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
18+
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
19+
$ configure-mlir-utils mlir
20+
```
21+
22+
### Details
23+
724
This package is meant to work in concert with the upstream bindings.
825
Practically speaking that means you need to have *some* package installed that includes mlir python bindings.
926
In addition, you have to do one of two things to **configure this package** (after installing it):
@@ -19,11 +36,6 @@ In addition, you have to do one of two things to **configure this package** (aft
1936
2037
There is a convenience `extra_requires` in `pyproject.toml` such that you can do this:
2138
22-
```shell
23-
$ pip install .[mlir] -f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
24-
$ configure-mlir-utils mlir
25-
```
26-
2739
# Examples
2840
2941
Check out the [tests](tests).

examples/throwaway.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,19 @@
88
generate_all_upstream_trampolines,
99
)
1010
from mlir_utils._configuration.configuration import _add_file_to_sources_txt_file
11+
from mlir_utils.dialects.ext.tensor import Tensor, S
1112

1213
# _add_file_to_sources_txt_file(Path("_configuration/__MLIR_PYTHON_PACKAGE_PREFIX__"))
13-
generate_all_upstream_trampolines()
14+
# generate_all_upstream_trampolines()
1415
from mlir_utils.dialects.memref import alloca_scope, return_
15-
from mlir_utils.dialects.transform import foreach, yield_
16+
from mlir_utils.dialects.tensor import generate, yield_, rank
17+
from mlir_utils.dialects.transform import foreach
1618
from mlir_utils.dialects import gpu
1719
from mlir_utils.dialects.ext import func
20+
from mlir_utils.dialects.ext.arith import constant
21+
from mlir_utils.types import f64, index
1822

19-
20-
from mlir_utils.dialects.util import constant
21-
22-
# # generate_all_upstream_trampolines()
23+
generate_all_upstream_trampolines()
2324
# from mlir.dialects.scf import WhileOp
2425
# from mlir.ir import InsertionPoint
2526
#
@@ -48,19 +49,24 @@
4849
# #
4950
#
5051
#
51-
# with mlir_mod_ctx() as ctx:
52-
# one = constant(1)
53-
#
54-
# @func.func
55-
# def demo_fun1():
56-
# one = constant(1)
57-
# return
58-
#
59-
# demo_fun1()
60-
# ctx.module.operation.verify()
52+
with mlir_mod_ctx() as ctx:
53+
54+
one = constant(1, index)
55+
two = constant(2, index)
56+
57+
@generate(
58+
Tensor[(S, 3, S), f64], dynamic_extents=[one, two], block_args=[index] * 3
59+
)
60+
def demo_fun1(i, j, k):
61+
one = constant(1.0)
62+
yield_(one)
63+
64+
r = rank(demo_fun1)
65+
66+
print(ctx.module)
67+
ctx.module.operation.verify()
6168
#
6269
#
6370
# print(ctx.module)
64-
# ctx.module.operation.verify()
6571
# print(ctx.module)
6672
# from importlib.resources import files

mlir_utils/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
11
from ._configuration.configuration import alias_upstream_bindings
2+
import atexit
23

3-
alias_upstream_bindings()
4+
if alias_upstream_bindings():
5+
from mlir import ir
6+
7+
DefaultContext = ir.Context()
8+
# Push a default context onto the context stack at import time.
9+
DefaultContext.__enter__()
10+
DefaultContext.allow_unregistered_dialects = False
11+
12+
DefaultLocation = ir.Location.unknown()
13+
DefaultLocation.__enter__()
14+
15+
@atexit.register
16+
def __exit_ctxt():
17+
DefaultContext.__exit__(None, None, None)
18+
19+
@atexit.register
20+
def __exit_loc():
21+
DefaultLocation.__exit__(None, None, None)

mlir_utils/_configuration/configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ def _add_file_to_sources_txt_file(file_path: Path):
3636
sources_file.write(
3737
f"{relative_file_path},sha256={encoded[:-1].decode()},{len(file)}\n"
3838
)
39-
else:
40-
raise RuntimeError("unsupported distribution scheme; please file a bug!")
4139

4240

4341
def _get_mlir_package_prefix():
@@ -58,13 +56,15 @@ def alias_upstream_bindings():
5856
get_meta_path_insertion_index(),
5957
AliasedModuleFinder({"mlir": mlir_python_package_prefix}),
6058
)
59+
return True
6160
elif not (
6261
sys.argv[0].endswith("configure-mlir-utils")
6362
or ("-m" in sys.orig_argv and __package__ in sys.orig_argv)
6463
):
6564
raise Exception(
6665
"mlir-utils not configured and MLIR_PYTHON_PACKAGE_PREFIX env variable not set"
6766
)
67+
return False
6868

6969

7070
def configure_host_bindings():

mlir_utils/_configuration/generate_trampolines.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,48 @@ def ast_call(name, args=None, keywords=None):
2323
)
2424

2525

26+
class FindOperands(ast.NodeVisitor):
27+
def __init__(self):
28+
self.operands = {}
29+
self.results = {}
30+
31+
def visit_Call(self, node: ast.Call):
32+
if hasattr(node.func, "value") and hasattr(node.func.value, "id"):
33+
if node.func.value.id == "operands":
34+
if isinstance(node.args[0], ast.Call):
35+
nested_call = node.args[0]
36+
is_optional = False
37+
elif isinstance(node.args[0], ast.IfExp):
38+
nested_call = node.args[0].body
39+
is_optional = True
40+
else:
41+
raise RuntimeError(
42+
f"unsupported operands python code: {ast.unparse(node)}"
43+
)
44+
oper_name = inflection.underscore(nested_call.args[0].id).lower()
45+
is_variadic = "values" in nested_call.func.id
46+
type = "list[Value]" if is_variadic else "Value"
47+
if is_optional:
48+
type = f"Optional[{type}]"
49+
self.operands[oper_name] = type
50+
elif node.func.value.id == "results":
51+
if node.func.attr == "extend":
52+
if isinstance(node.args[0], ast.BinOp):
53+
# something like results.extend([operands[0].type] * 1)
54+
return
55+
else:
56+
self.results[node.args[0].id] = "list[Type]"
57+
elif node.func.attr == "append":
58+
self.results[node.args[0].id] = "Type"
59+
else:
60+
raise ValueError("unknown results object")
61+
62+
2663
# TODO(max): ops that have symboltables need to be classes but that requires some upstream support for statically
2764
# identifying such ops
2865
def generate_op_trampoline(op_class):
66+
from mlir_utils.dialects.util import get_result_or_results, maybe_cast, region_op
67+
2968
_mod = ast.parse(dedent(inspect.getsource(op_class.__init__)))
3069
init_fn = next(n for n in _mod.body if isinstance(n, ast.FunctionDef))
3170
args = init_fn.args
@@ -41,8 +80,6 @@ def generate_op_trampoline(op_class):
4180
for k, d in zip(args.kwonlyargs, args.kw_defaults)
4281
]
4382

44-
for a in args.args + args.kwonlyargs:
45-
a.annotation = None
4683
fun_name = op_class.OPERATION_NAME.split(".")[-1]
4784
if keyword.iskeyword(fun_name):
4885
fun_name = fun_name + "_"
@@ -56,18 +93,27 @@ def generate_op_trampoline(op_class):
5693
and op_class._ODS_REGIONS[0] == 1
5794
and not op_class.OPERATION_NAME.startswith("linalg")
5895
):
59-
decorator_list = [ast.Name(id="region_op", ctx=ast.Load())]
96+
decorator_list = [ast.Name(id=region_op.__name__, ctx=ast.Load())]
6097
body += [ast.Return([ast_call(op_class_name, args.args, keywords)])]
6198
else:
6299
decorator_list = []
63100
body += [
64101
ast.parse(
65-
f"return get_result_or_results({ast.unparse(ast_call(op_class_name, args.args, keywords))})"
102+
f"return {maybe_cast.__name__}({get_result_or_results.__name__}({ast.unparse(ast_call(op_class_name, args.args, keywords))}))"
66103
).body[0]
67104
]
105+
106+
args = copy.deepcopy(args)
107+
oper_finder = FindOperands()
108+
oper_finder.visit(init_fn)
109+
for a in args.args:
110+
if a.arg in oper_finder.operands:
111+
a.annotation = ast.Name(id=oper_finder.operands[a.arg], ctx=ast.Load())
112+
elif a.arg in oper_finder.results:
113+
a.annotation = ast.Name(id=oper_finder.results[a.arg], ctx=ast.Load())
68114
n = ast.FunctionDef(
69115
name=fun_name,
70-
args=copy.deepcopy(args),
116+
args=args,
71117
body=body,
72118
decorator_list=decorator_list,
73119
)
@@ -77,8 +123,9 @@ def generate_op_trampoline(op_class):
77123

78124
def generate_dialect_trampolines_from_module(input_module, skips: set):
79125
import mlir_utils
80-
from mlir_utils.dialects.util import get_result_or_results
126+
from mlir_utils.dialects.util import get_result_or_results, maybe_cast, region_op
81127
import mlir.dialects._ods_common
128+
from mlir_utils._configuration.configuration import _get_mlir_package_prefix
82129

83130
skips.update({"_Dialect"})
84131
init_funs = {}
@@ -92,6 +139,7 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
92139
# these are extension classes and we should wrap the generated class instead
93140
obj = obj.__base__
94141
if not inspect.isfunction(obj.__init__):
142+
print(f"skipping {obj.__name__} because it has no __init__")
95143
# some builders don't have any __init__ but inherit from opview
96144
continue
97145
init_funs[obj.__name__] = obj
@@ -104,9 +152,17 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
104152
for op_class in sorted(init_funs.values(), key=lambda o: o.__name__)
105153
]
106154

155+
ir_imports = ast.ImportFrom(
156+
module=_get_mlir_package_prefix() + ".ir",
157+
names=[ast.alias(i) for i in ["Value", "Attribute", "Type"]],
158+
level=0,
159+
)
107160
ods_imports = ast.ImportFrom(
108161
module=mlir_utils.dialects.util.__name__,
109-
names=[ast.alias(get_result_or_results.__name__), ast.alias("region_op")],
162+
names=[
163+
ast.alias(f.__name__)
164+
for f in [get_result_or_results, maybe_cast, region_op]
165+
],
110166
level=0,
111167
)
112168
op_imports = ast.ImportFrom(
@@ -125,7 +181,11 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
125181
else:
126182
linalg_imports = []
127183

128-
new_mod = ast.Module([op_imports, *linalg_imports, ods_imports] + functions, [])
184+
all = ast.parse(f"__all__ = [{', '.join(repr(f.name) for f in functions)}]")
185+
186+
new_mod = ast.Module(
187+
[ir_imports, op_imports, *linalg_imports, ods_imports] + functions + [all], []
188+
)
129189
new_src = ast.unparse(new_mod)
130190
return black.format_file_contents(new_src, fast=False, mode=black.Mode())
131191

0 commit comments

Comments
 (0)