Skip to content

Commit 50ed0e9

Browse files
committed
feat: ARRAY_API_TESTS_MODULE for runtime-defined xp
1 parent f7a74a6 commit 50ed0e9

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

array_api_tests/__init__.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,27 @@
1313
# You can comment the following out and instead import the specific array module
1414
# you want to test, e.g. `import array_api_strict as xp`.
1515
if "ARRAY_API_TESTS_MODULE" in os.environ:
16-
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
17-
_module, _sub = xp_name, None
18-
if "." in xp_name:
19-
_module, _sub = xp_name.split(".", 1)
20-
xp = import_module(_module)
21-
if _sub:
22-
try:
23-
xp = getattr(xp, _sub)
24-
except AttributeError:
25-
# _sub may be a submodule that needs to be imported. WE can't
26-
# do this in every case because some array modules are not
27-
# submodules that can be imported (like mxnet.nd).
28-
xp = import_module(xp_name)
16+
env_var = os.environ["ARRAY_API_TESTS_MODULE"]
17+
if env_var.startswith("exec(") and env_var.endswith(")"):
18+
script = env_var[5:][:-1]
19+
namespace = {}
20+
exec(script, namespace)
21+
xp = namespace["xp"]
22+
xp_name = xp.__name__
23+
else:
24+
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
25+
_module, _sub = xp_name, None
26+
if "." in xp_name:
27+
_module, _sub = xp_name.split(".", 1)
28+
xp = import_module(_module)
29+
if _sub:
30+
try:
31+
xp = getattr(xp, _sub)
32+
except AttributeError:
33+
# _sub may be a submodule that needs to be imported. WE can't
34+
# do this in every case because some array modules are not
35+
# submodules that can be imported (like mxnet.nd).
36+
xp = import_module(xp_name)
2937
else:
3038
raise RuntimeError(
3139
"No array module specified - either edit __init__.py or set the "

0 commit comments

Comments
 (0)