Skip to content

Commit d7668eb

Browse files
authored
explicit_package (#82)
1 parent eb0bbca commit d7668eb

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

setup.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
from pip._internal.req import parse_requirements
4-
from setuptools import setup
4+
from setuptools import setup, find_namespace_packages
55

66
# TODO: find from extras maybe
77
HOST_MLIR_PYTHON_PACKAGE_PREFIX = os.environ.get(
@@ -15,6 +15,12 @@ def load_requirements(fname):
1515
return [str(ir.requirement) for ir in reqs]
1616

1717

18+
packages = [
19+
f"{HOST_MLIR_PYTHON_PACKAGE_PREFIX}.extras.{p}"
20+
for p in find_namespace_packages(where="mlir/extras")
21+
] + [f"{HOST_MLIR_PYTHON_PACKAGE_PREFIX}.extras"]
22+
23+
1824
setup(
1925
name=PACKAGE_NAME,
2026
version="0.0.7",
@@ -28,6 +34,7 @@ def load_requirements(fname):
2834
"mlir": ["mlir-python-bindings"],
2935
},
3036
python_requires=">=3.8",
37+
packages=packages,
3138
# lhs is package namespace, rhs is path (relative to this setup.py)
3239
package_dir={
3340
f"{HOST_MLIR_PYTHON_PACKAGE_PREFIX}.extras": "mlir/extras",

tests/test_other_hosts.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,6 @@
66
@pytest.mark.skipif(jax_not_installed(), reason="jax not installed")
77
def test_jax_trampolines_smoke():
88
# noinspection PyUnresolvedReferences
9-
from jaxlib.mlir.dialects import (
10-
arith,
11-
builtin,
12-
chlo,
13-
func,
14-
math,
15-
memref,
16-
mhlo,
17-
scf,
18-
sparse_tensor,
19-
stablehlo,
20-
vector,
21-
)
9+
from jaxlib.mlir.extras import context
10+
# noinspection PyUnresolvedReferences
11+
from jaxlib.mlir.extras.runtime import passes

0 commit comments

Comments
 (0)