Skip to content

Commit 0b02513

Browse files
committed
add trampolines and correct upstream binding loading
1 parent 719ca6d commit 0b02513

File tree

19 files changed

+367
-110
lines changed

19 files changed

+367
-110
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ on:
1313

1414
jobs:
1515

16-
test-with-configuration:
16+
test-against-torch-mlir-bindings:
1717

1818
runs-on: ${{ matrix.os }}
1919
continue-on-error: true
@@ -35,39 +35,12 @@ jobs:
3535
- name: Install and configure
3636
shell: bash
3737
run: |
38-
export PIP_EXTRA_INDEX_URL=https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
39-
pip install .[test] -v
40-
configure-mlir-utils -y mlir
38+
pip install .[torch-mlir-test] \
39+
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest \
40+
-f https://llvm.github.io/torch-mlir/package-index/
41+
pushd /tmp && python -m mlir_utils.__configuration -y torch_mlir && popd
4142
4243
- name: Test
4344
shell: bash
4445
run: |
4546
pytest tests
46-
47-
test-with-env:
48-
49-
runs-on: ${{ matrix.os }}
50-
51-
strategy:
52-
fail-fast: false
53-
matrix:
54-
os: [ ubuntu-22.04, macos-11 ]
55-
56-
steps:
57-
- name: Checkout
58-
uses: actions/checkout@v2
59-
60-
- name: Setup Python
61-
uses: actions/setup-python@v4
62-
with:
63-
python-version: "3.11"
64-
65-
- name: Install and configure
66-
run: |
67-
export PIP_EXTRA_INDEX_URL=https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
68-
pip install .[test] -v
69-
70-
- name: Test
71-
run: |
72-
export MLIR_PYTHON_PACKAGE_PREFIX=mlir
73-
pytest tests

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# mlir-utils
2+
3+
## Dev
4+
5+
```shell
6+
# you need setuptools >= 64 for build_editable
7+
pip install setuptools -U
8+
pip install -e ".[mlir-test]" \
9+
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
10+
```

examples/demo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from mlir_utils.dialects.generate_trampolines import generate_all_upstream_trampolines
2+
3+
generate_all_upstream_trampolines()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .configuration import *
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .configuration import *
2+
3+
configure_host_bindings()
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import argparse
2+
import importlib
3+
import os
4+
import pkgutil
5+
import sys
6+
from pathlib import Path
7+
8+
9+
__MLIR_PYTHON_PACKAGE_PREFIX__ = None
10+
THIS_DIR = Path(__file__).resolve().parent
11+
MLIR_PYTHON_PACKAGE_PREFIX_FILE_PATH = THIS_DIR / "__MLIR_PYTHON_PACKAGE_PREFIX__"
12+
13+
14+
def import_submodules(package_name):
15+
package = sys.modules[package_name]
16+
return {
17+
name: importlib.import_module(package_name + "." + name)
18+
for loader, name, is_pkg in pkgutil.walk_packages(package.__path__)
19+
}
20+
21+
22+
def load_upstream_bindings():
23+
global __MLIR_PYTHON_PACKAGE_PREFIX__
24+
25+
if MLIR_PYTHON_PACKAGE_PREFIX_FILE_PATH.exists():
26+
with open(MLIR_PYTHON_PACKAGE_PREFIX_FILE_PATH) as f:
27+
__MLIR_PYTHON_PACKAGE_PREFIX__ = f.read().strip()
28+
29+
if os.getenv("MLIR_PYTHON_PACKAGE_PREFIX"):
30+
__MLIR_PYTHON_PACKAGE_PREFIX__ = os.getenv("MLIR_PYTHON_PACKAGE_PREFIX")
31+
32+
if __MLIR_PYTHON_PACKAGE_PREFIX__ is not None:
33+
_mlir = sys.modules["mlir"] = __import__(
34+
__MLIR_PYTHON_PACKAGE_PREFIX__, globals(), locals(), fromlist=["*"]
35+
)
36+
for submod in ["ir", "dialects", "_mlir_libs"]:
37+
sys.modules[f"mlir.{submod}"] = __import__(
38+
f"{__MLIR_PYTHON_PACKAGE_PREFIX__}.{submod}",
39+
globals(),
40+
locals(),
41+
fromlist=["*"],
42+
)
43+
mlir_modules = {}
44+
for name, mod in sys.modules.items():
45+
if name.startswith(__MLIR_PYTHON_PACKAGE_PREFIX__ + "."):
46+
mlir_name = (
47+
"mlir." + name[len(__MLIR_PYTHON_PACKAGE_PREFIX__ + ".") + 1 :]
48+
)
49+
mlir_modules[mlir_name] = mod
50+
sys.modules.update(mlir_modules)
51+
52+
else:
53+
if not (
54+
sys.argv[0].endswith("configure-mlir-utils")
55+
or ("-m" in sys.orig_argv and "mlir_utils.__configuration" in sys.orig_argv)
56+
):
57+
raise Exception(
58+
"mlir-utils not configured and MLIR_PYTHON_PACKAGE_PREFIX env variable not set"
59+
)
60+
61+
62+
def configure_host_bindings():
63+
parser = argparse.ArgumentParser(
64+
prog="configure-mlir-utils",
65+
description="Configure mlir-utils",
66+
)
67+
parser.add_argument("-y", "--yes", action="store_true", default=False)
68+
parser.add_argument("mlir_python_package_prefix")
69+
args = parser.parse_args()
70+
mlir_python_package_prefix = args.mlir_python_package_prefix
71+
assert mlir_python_package_prefix, "missing mlir_python_package_prefix"
72+
mlir_python_package_prefix = (
73+
mlir_python_package_prefix.replace("'", "").replace('"', "").strip()
74+
)
75+
76+
if bool(__MLIR_PYTHON_PACKAGE_PREFIX__):
77+
print(
78+
f'mlir_python_package_prefix has already been set to "{__MLIR_PYTHON_PACKAGE_PREFIX__}"'
79+
)
80+
if not args.yes:
81+
answer = input("do you want to reset? [y/n]: ")
82+
if answer.lower() not in {"1", "true", "yes", "y"}:
83+
return
84+
85+
if not args.yes:
86+
answer = input(f"new {mlir_python_package_prefix=}; continue? [y/n]: ")
87+
if answer.lower() not in {"1", "true", "yes", "y"}:
88+
return
89+
else:
90+
print(f"new {mlir_python_package_prefix=}")
91+
92+
# check if valid package/module
93+
try:
94+
_host_bindings_mlir = __import__(f"{mlir_python_package_prefix}._mlir_libs")
95+
except (ImportError, ModuleNotFoundError) as e:
96+
print(f"couldn't import {mlir_python_package_prefix=} due to: {e}")
97+
raise e
98+
99+
with open(MLIR_PYTHON_PACKAGE_PREFIX_FILE_PATH, "w") as f:
100+
f.write(mlir_python_package_prefix)
101+
102+
load_upstream_bindings()
103+
104+
from ..dialects.generate_trampolines import generate_all_upstream_trampolines
105+
106+
generate_all_upstream_trampolines()

mlir_utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .__configuration import load_upstream_bindings
2+
3+
load_upstream_bindings()

mlir_utils/configuration.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

mlir_utils/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from typing import Optional
44

5-
from .configuration import _host_bindings_mlir as mlir
5+
import mlir
66

77

88
@dataclass
@@ -20,7 +20,7 @@ def mlir_mod_ctx(
2020
context: mlir.ir.Context = None,
2121
location: mlir.ir.Location = None,
2222
allow_unregistered_dialects=False,
23-
):
23+
) -> MLIRContext:
2424
if context is None:
2525
context = mlir.ir.Context()
2626
if allow_unregistered_dialects:

0 commit comments

Comments
 (0)