Skip to content

Commit 74621bb

Browse files
committed
fix upstream binding aliasing yet again
1 parent 323684b commit 74621bb

File tree

10 files changed

+242
-96
lines changed

10 files changed

+242
-96
lines changed

README.md

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,39 @@
11
# mlir-utils
22

3+
The missing pieces (as far as boilerplate reduction goes) of the upstream MLIR python bindings.
4+
5+
## Install
6+
7+
This package is meant to work in concert with the upstream bindings.
8+
Practically speaking that means you need to have *some* package installed that includes mlir python bindings.
9+
In addition, you have to do one of two things to **configure this package** (after installing it):
10+
11+
1. `$ configure-mlir-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the
12+
package prefix for your chosen upstream bindings. So for example, for `torch-mlir`, you would
13+
execute `configure-mlir-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir` python
14+
package. **When in doubt about this prefix**, it is everything up until `ir` (e.g., as
15+
in `from torch_mlir import ir`).
16+
2. `$ export MLIR_PYTHON_PACKAGE_PREFIX=<MLIR_PYTHON_PACKAGE_PREFIX>`, i.e., you can set this string as an environment
17+
variable each time you use this package. Note, in this case, if you want to make use of the "op trampolines", you
18+
still need to run `generate_trampolines.generate_all_upstream_trampolines()` by hand at least once.
19+
20+
There is a convenience `extra_requires` in `pyproject.toml` such that you can do this:
21+
22+
```shell
23+
$ pip install .[mlir] -f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
24+
$ configure-mlir-utils mlir
25+
```
26+
27+
# Examples
28+
29+
Check out the [tests](tests).
30+
331
## Dev
432

533
```shell
634
# you need setuptools >= 64 for build_editable
735
pip install setuptools -U
8-
pip install -e ".[mlir-test]" \
9-
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
36+
pip install -e .[torch-mlir-test] \
37+
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest \
38+
-f https://llvm.github.io/torch-mlir/package-index/
1039
```

examples/demo.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/throwaway.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import inspect
2+
from functools import wraps
3+
4+
from mlir_utils.context import mlir_mod_ctx
5+
from mlir_utils.dialects.generate_trampolines import generate_all_upstream_trampolines
6+
from mlir_utils.dialects.memref import alloca_scope, return_
7+
from mlir_utils.dialects.transform import foreach, yield_
8+
from mlir_utils.dialects import gpu
9+
from mlir_utils.dialects.ext import func
10+
11+
12+
from mlir_utils.dialects.util import constant
13+
14+
generate_all_upstream_trampolines()
15+
# # generate_all_upstream_trampolines()
16+
# from mlir.dialects.scf import WhileOp
17+
# from mlir.ir import InsertionPoint
18+
#
19+
#
20+
# # from mlir_utils.dialects.scf import execute_region, yield_
21+
#
22+
#
23+
# # def doublewrap(f):
24+
# # """
25+
# # a decorator decorator, allowing the decorator to be used as:
26+
# # @decorator(with, arguments, and=kwargs)
27+
# # or
28+
# # @decorator
29+
# # """
30+
# #
31+
# # @wraps(f)
32+
# # def new_dec(*args, **kwargs):
33+
# # if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
34+
# # # actual decorated function
35+
# # return f(args[0])
36+
# # else:
37+
# # # decorator arguments
38+
# # return lambda realf: f(realf, *args, **kwargs)
39+
# #
40+
# # return new_dec
41+
# #
42+
#
43+
#
44+
# with mlir_mod_ctx() as ctx:
45+
# one = constant(1)
46+
#
47+
# @func.func
48+
# def demo_fun1():
49+
# one = constant(1)
50+
# return
51+
#
52+
# demo_fun1()
53+
# ctx.module.operation.verify()
54+
#
55+
#
56+
# print(ctx.module)
57+
# ctx.module.operation.verify()
58+
# print(ctx.module)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import importlib
2+
import importlib.util
3+
import sys
4+
from importlib.abc import Loader, MetaPathFinder
5+
from importlib.machinery import ModuleSpec, PathFinder
6+
from types import ModuleType
7+
from typing import Mapping, Optional, Sequence, Union
8+
9+
# stolen from https://github.com/dagster-io/dagster/blob/master/python_modules/dagster/dagster/_module_alias_map.py
10+
11+
# The AliasedModuleFinder should be inserted in front of the built-in PathFinder.
12+
def get_meta_path_insertion_index() -> int:
13+
for i in range(len(sys.meta_path)):
14+
finder = sys.meta_path[i]
15+
if isinstance(finder, type) and issubclass(finder, PathFinder):
16+
return i
17+
raise Exception(
18+
"Could not find the built-in PathFinder in sys.meta_path-- cannot insert the"
19+
" AliasedModuleFinder"
20+
)
21+
22+
23+
# Key reference to understand the load process:
24+
# https://docs.python.org/3/reference/import.html#loading
25+
26+
27+
# While it is possible to override `Loader.create_module` to simply return the base module, this
28+
# is undesirable because the import system modifies the module's metadata attributes after
29+
# creation and outside of our control. This means that, if we simply had:
30+
#
31+
# def create_module(self, spec):
32+
# return importlib.import_module(self.base_spec_name)
33+
#
34+
# The returned base module would have its name etc modified, e.g. the already-loaded
35+
# `dagster._core` would be renamed to alias `dagster.core`. To avoid this, we let the import system
36+
# generate a module using default logic, then simply discard this module in `exec_module` (the final
37+
# step), where it is passed in. This is the point at which we swap in the base module, which we
38+
# obtain through `importlib.import_module` (like a standard import statetment, this will simply
39+
# returned the cached module from `sys.modules` if it has already been loaded). The swap is done by
40+
# simply replacing the dummy module (already stored in `sys.modules` outside of our control) with
41+
# the imported base module.
42+
class AliasedModuleLoader(Loader):
43+
def __init__(self, alias: str, base_spec: ModuleSpec):
44+
self.alias = alias
45+
self.base_spec = base_spec
46+
47+
def exec_module(self, _module: ModuleType) -> None:
48+
base_module = importlib.import_module(self.base_spec.name)
49+
sys.modules[self.alias] = base_module
50+
51+
def module_repr(self, module: ModuleType) -> str:
52+
assert self.base_spec.loader
53+
return self.base_spec.loader.module_repr(module)
54+
55+
56+
class AliasedModuleFinder(MetaPathFinder):
57+
def __init__(self, alias_map: Mapping[str, str]):
58+
self.alias_map = alias_map
59+
60+
def find_spec(
61+
self,
62+
fullname: str,
63+
_path: Optional[Sequence[Union[bytes, str]]] = None,
64+
_target: Optional[ModuleType] = None,
65+
) -> Optional[ModuleSpec]:
66+
head = next(
67+
(
68+
k
69+
for k in self.alias_map.keys()
70+
if fullname == k or fullname.startswith(k + ".")
71+
),
72+
None,
73+
)
74+
if head is not None:
75+
base_name = self.alias_map[head] + fullname[len(head) :]
76+
base_spec = importlib.util.find_spec(base_name)
77+
assert base_spec, f"Could not find module spec for {base_name}."
78+
return ModuleSpec(
79+
fullname,
80+
AliasedModuleLoader(fullname, base_spec),
81+
origin=base_spec.origin,
82+
is_package=base_spec.submodule_search_locations is not None,
83+
)
84+
else:
85+
return None

mlir_utils/__configuration/configuration.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
import argparse
2-
import importlib
32
import os
4-
import pkgutil
53
import sys
64
from pathlib import Path
75

6+
from mlir_utils.__configuration._module_alias_map import (
7+
get_meta_path_insertion_index,
8+
AliasedModuleFinder,
9+
)
810

9-
__MLIR_PYTHON_PACKAGE_PREFIX__ = None
1011
THIS_DIR = Path(__file__).resolve().parent
1112
MLIR_PYTHON_PACKAGE_PREFIX_FILE_PATH = THIS_DIR / "__MLIR_PYTHON_PACKAGE_PREFIX__"
12-
13-
14-
def import_submodules(package_name):
15-
package = sys.modules[package_name]
16-
return {
17-
name: importlib.import_module(package_name + "." + name)
18-
for loader, name, is_pkg in pkgutil.walk_packages(package.__path__)
19-
}
13+
__MLIR_PYTHON_PACKAGE_PREFIX__ = None
2014

2115

2216
def load_upstream_bindings():
@@ -28,35 +22,18 @@ def load_upstream_bindings():
2822

2923
if os.getenv("MLIR_PYTHON_PACKAGE_PREFIX"):
3024
__MLIR_PYTHON_PACKAGE_PREFIX__ = os.getenv("MLIR_PYTHON_PACKAGE_PREFIX")
31-
3225
if __MLIR_PYTHON_PACKAGE_PREFIX__ is not None:
33-
_mlir = sys.modules["mlir"] = __import__(
34-
__MLIR_PYTHON_PACKAGE_PREFIX__, globals(), locals(), fromlist=["*"]
26+
sys.meta_path.insert(
27+
get_meta_path_insertion_index(),
28+
AliasedModuleFinder({"mlir": __MLIR_PYTHON_PACKAGE_PREFIX__}),
29+
)
30+
elif not (
31+
sys.argv[0].endswith("configure-mlir-utils")
32+
or ("-m" in sys.orig_argv and "mlir_utils.__configuration" in sys.orig_argv)
33+
):
34+
raise Exception(
35+
"mlir-utils not configured and MLIR_PYTHON_PACKAGE_PREFIX env variable not set"
3536
)
36-
for submod in ["ir", "dialects", "_mlir_libs"]:
37-
sys.modules[f"mlir.{submod}"] = __import__(
38-
f"{__MLIR_PYTHON_PACKAGE_PREFIX__}.{submod}",
39-
globals(),
40-
locals(),
41-
fromlist=["*"],
42-
)
43-
mlir_modules = {}
44-
for name, mod in sys.modules.items():
45-
if name.startswith(__MLIR_PYTHON_PACKAGE_PREFIX__ + "."):
46-
mlir_name = (
47-
"mlir." + name[len(__MLIR_PYTHON_PACKAGE_PREFIX__ + ".") + 1 :]
48-
)
49-
mlir_modules[mlir_name] = mod
50-
sys.modules.update(mlir_modules)
51-
52-
else:
53-
if not (
54-
sys.argv[0].endswith("configure-mlir-utils")
55-
or ("-m" in sys.orig_argv and "mlir_utils.__configuration" in sys.orig_argv)
56-
):
57-
raise Exception(
58-
"mlir-utils not configured and MLIR_PYTHON_PACKAGE_PREFIX env variable not set"
59-
)
6037

6138

6239
def configure_host_bindings():

mlir_utils/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from typing import Optional
44

5-
import mlir
5+
import mlir.ir
66

77

88
@dataclass

mlir_utils/dialects/ext/func.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import inspect
22
from functools import wraps
33

4-
from mlir.dialects.func import FuncOp, ReturnOp
5-
from mlir.ir import InsertionPoint, FunctionType, StringAttr, TypeAttr
4+
from mlir.dialects.func import FuncOp, ReturnOp, CallOp
5+
from mlir.ir import (
6+
InsertionPoint,
7+
FunctionType,
8+
StringAttr,
9+
TypeAttr,
10+
FlatSymbolRefAttr,
11+
)
612

713
from mlir_utils.dialects.util import (
814
get_result_or_results,
@@ -16,18 +22,20 @@ def func(sym_visibility=None, arg_attrs=None, res_attrs=None, loc=None, ip=None)
1622

1723
def builder_wrapper(body_builder):
1824
@wraps(body_builder)
19-
def wrapper(*args):
25+
def wrapper(*call_args):
2026
sig = inspect.signature(body_builder)
2127
implicit_return = sig.return_annotation is inspect._empty
28+
input_types = [a.type for a in call_args]
2229
function_type = TypeAttr.get(
2330
FunctionType.get(
24-
inputs=[a.type for a in args],
31+
inputs=input_types,
2532
results=[] if implicit_return else sig.return_annotation,
2633
)
2734
)
2835
# FuncOp is extended but we do really want the base
29-
op = FuncOp.__base__(
30-
body_builder.__name__,
36+
func_name = body_builder.__name__
37+
func_op = FuncOp.__base__(
38+
func_name,
3139
function_type,
3240
sym_visibility=StringAttr.get(str(sym_visibility))
3341
if sym_visibility is not None
@@ -37,19 +45,30 @@ def wrapper(*args):
3745
loc=loc,
3846
ip=ip,
3947
)
40-
op.regions[0].blocks.append(*[a.type for a in args])
41-
with InsertionPoint(op.regions[0].blocks[0]):
42-
r = get_result_or_results(
43-
body_builder(*op.regions[0].blocks[0].arguments)
48+
func_op.regions[0].blocks.append(*[a.type for a in call_args])
49+
with InsertionPoint(func_op.regions[0].blocks[0]):
50+
results = get_result_or_results(
51+
body_builder(*func_op.regions[0].blocks[0].arguments)
4452
)
45-
if r is not None:
46-
if isinstance(r, (tuple, list)):
47-
ReturnOp(list(r))
53+
if results is not None:
54+
if isinstance(results, (tuple, list)):
55+
results = list(results)
4856
else:
49-
ReturnOp([r])
57+
results = [results]
5058
else:
51-
ReturnOp([])
52-
return r
59+
results = []
60+
ReturnOp(results)
61+
# Recompute the function type.
62+
return_types = [v.type for v in results]
63+
function_type = FunctionType.get(inputs=input_types, results=return_types)
64+
func_op.attributes["function_type"] = TypeAttr.get(function_type)
65+
66+
call_op = CallOp(
67+
[r.type for r in results], FlatSymbolRefAttr.get(func_name), call_args
68+
)
69+
if results is None:
70+
return None
71+
return get_result_or_results(call_op)
5372

5473
# wrapper.op = op
5574
return wrapper

0 commit comments

Comments
 (0)