Skip to content

Commit 80af1ea

Browse files
committed
add jax as a test of upstream bindings
1 parent 3dee9c5 commit 80af1ea

File tree

6 files changed

+102
-10
lines changed

6 files changed

+102
-10
lines changed

.github/workflows/test.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,40 @@ jobs:
4949
shell: bash
5050
run: |
5151
pytest tests
52+
53+
54+
test-against-jax-bindings:
55+
56+
runs-on: ${{ matrix.os }}
57+
continue-on-error: true
58+
59+
strategy:
60+
fail-fast: false
61+
matrix:
62+
os: [ ubuntu-22.04, macos-11, windows-2022 ]
63+
64+
steps:
65+
- name: Checkout
66+
uses: actions/checkout@v2
67+
68+
- name: Setup Python
69+
uses: actions/setup-python@v4
70+
with:
71+
python-version: "3.11"
72+
73+
- name: Install and configure
74+
shell: bash
75+
run: |
76+
pip install .[jax-test] \
77+
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
78+
if [ ${{ matrix.os }} == 'windows-2022' ]; then
79+
# configure-mlir-python-utils.exe -y jaxlib.mlir
80+
pushd /tmp && python -m mlir_utils._configuration -y jaxlib.mlir && popd
81+
else
82+
configure-mlir-python-utils -y jaxlib.mlir
83+
fi
84+
85+
- name: Test
86+
shell: bash
87+
run: |
88+
pytest tests/test_smoke.py

mlir_utils/_configuration/generate_trampolines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def generate_op_trampoline(op_class):
8080
for k, d in zip(args.kwonlyargs, args.kw_defaults)
8181
]
8282

83-
fun_name = op_class.OPERATION_NAME.split(".")[-1]
83+
fun_name = op_class.OPERATION_NAME.split(".")[-1].replace("-", "_")
8484
if keyword.iskeyword(fun_name):
8585
fun_name = fun_name + "_"
8686
op_class_name = op_class.__name__

mlir_utils/dialects/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@
33
/bufferization.py
44
/builtin.py
55
/cf.py
6+
/chlo.py
67
/complex.py
78
/func.py
89
/gpu.py
910
/linalg.py
1011
/math.py
1112
/memref.py
13+
/mhlo.py
1214
/ml_program.py
1315
/pdl.py
1416
/quant.py
1517
/scf.py
1618
/shape.py
1719
/sparse_tensor.py
20+
/stablehlo.py
1821
/tensor.py
1922
/torch.py
2023
/tosa.py

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414

1515
[project.optional-dependencies]
1616
torch-mlir-test = ["pytest", "torch-mlir-core", "mlir-native-tools"]
17+
jax-test = ["pytest", "jax[cpu]", "mlir-native-tools"]
1718
mlir-test = ["pytest", "mlir-python-bindings", "mlir-native-tools"]
1819
mlir = ["mlir-python-bindings"]
1920

tests/test_smoke.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
# noinspection PyUnresolvedReferences
1313
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
14+
from util import skip_jax_not_installed, skip_torch_mlir_not_installed
1415

1516
# needed since the fix isn't defined here nor conftest.py
1617
pytest.mark.usefixtures("ctx")
@@ -26,6 +27,8 @@ def test_smoke(ctx: MLIRContext):
2627
filecheck(correct, ctx.module)
2728

2829

30+
# skip if jax *is* installed because jax doesn't generate almost any of the upstream dialects
31+
@pytest.mark.skipif(lambda: not skip_jax_not_installed(), reason="jax installed")
2932
def test_dialect_trampolines_smoke():
3033
# noinspection PyUnresolvedReferences
3134
from mlir_utils.dialects import (
@@ -51,15 +54,18 @@ def test_dialect_trampolines_smoke():
5154
)
5255

5356

54-
def skip_torch_mlir_not_installed():
55-
try:
56-
from torch_mlir.dialects import torch
57-
58-
# don't skip
59-
return False
60-
except ImportError:
61-
# skip
62-
return True
57+
@pytest.mark.skipif(skip_jax_not_installed(), reason="jax not installed")
58+
def test_dialect_trampolines_smoke():
59+
# noinspection PyUnresolvedReferences
60+
from mlir_utils.dialects import (
61+
builtin,
62+
chlo,
63+
func,
64+
mhlo,
65+
ml_program,
66+
sparse_tensor,
67+
stablehlo,
68+
)
6369

6470

6571
@pytest.mark.skipif(skip_torch_mlir_not_installed(), reason="torch_mlir not installed")
@@ -74,3 +80,26 @@ def test_torch_dialect_trampolines_smoke():
7480
)
7581
# noinspection PyUnresolvedReferences
7682
from mlir_utils.dialects import torch
83+
84+
85+
@pytest.mark.skipif(skip_jax_not_installed(), reason="jax not installed")
86+
def test_jax_trampolines_smoke():
87+
for mod in ["chlo", "mhlo", "stablehlo"]:
88+
try:
89+
modu = __import__(f"mlir_utils.dialects.{mod}", fromlist=["*"])
90+
os.remove(modu.__file__)
91+
except ModuleNotFoundError:
92+
pass
93+
generate_trampolines(
94+
f"jaxlib.mlir.dialects.{mod}", Path(mlir_utils.dialects.__path__[0]), mod
95+
)
96+
# noinspection PyUnresolvedReferences
97+
from mlir_utils.dialects import (
98+
builtin,
99+
chlo,
100+
func,
101+
mhlo,
102+
ml_program,
103+
sparse_tensor,
104+
stablehlo,
105+
)

tests/util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
def skip_torch_mlir_not_installed():
2+
try:
3+
from torch_mlir.dialects import torch
4+
5+
# don't skip
6+
return False
7+
8+
except ImportError:
9+
# skip
10+
return True
11+
12+
13+
def skip_jax_not_installed():
14+
try:
15+
from jaxlib import mlir
16+
17+
# don't skip
18+
return False
19+
20+
except ImportError:
21+
# skip
22+
return True

0 commit comments

Comments
 (0)