File tree Expand file tree Collapse file tree 2 files changed +11
-14
lines changed Expand file tree Collapse file tree 2 files changed +11
-14
lines changed Original file line number Diff line number Diff line change 1
1
import os
2
2
3
3
from pip ._internal .req import parse_requirements
4
- from setuptools import setup
4
+ from setuptools import setup , find_namespace_packages
5
5
6
6
# TODO: find from extras maybe
7
7
HOST_MLIR_PYTHON_PACKAGE_PREFIX = os .environ .get (
@@ -15,6 +15,12 @@ def load_requirements(fname):
15
15
return [str (ir .requirement ) for ir in reqs ]
16
16
17
17
18
+ packages = [
19
+ f"{ HOST_MLIR_PYTHON_PACKAGE_PREFIX } .extras.{ p } "
20
+ for p in find_namespace_packages (where = "mlir/extras" )
21
+ ] + [f"{ HOST_MLIR_PYTHON_PACKAGE_PREFIX } .extras" ]
22
+
23
+
18
24
setup (
19
25
name = PACKAGE_NAME ,
20
26
version = "0.0.7" ,
@@ -28,6 +34,7 @@ def load_requirements(fname):
28
34
"mlir" : ["mlir-python-bindings" ],
29
35
},
30
36
python_requires = ">=3.8" ,
37
+ packages = packages ,
31
38
# lhs is package namespace, rhs is path (relative to this setup.py)
32
39
package_dir = {
33
40
f"{ HOST_MLIR_PYTHON_PACKAGE_PREFIX } .extras" : "mlir/extras" ,
Original file line number Diff line number Diff line change 6
6
@pytest .mark .skipif (jax_not_installed (), reason = "jax not installed" )
7
7
def test_jax_trampolines_smoke ():
8
8
# noinspection PyUnresolvedReferences
9
- from jaxlib .mlir .dialects import (
10
- arith ,
11
- builtin ,
12
- chlo ,
13
- func ,
14
- math ,
15
- memref ,
16
- mhlo ,
17
- scf ,
18
- sparse_tensor ,
19
- stablehlo ,
20
- vector ,
21
- )
9
+ from jaxlib .mlir .extras import context
10
+ # noinspection PyUnresolvedReferences
11
+ from jaxlib .mlir .extras .runtime import passes
You can’t perform that action at this time.
0 commit comments