Skip to content

Commit 978053a

Browse files
committed
WIP test vs extra names
1 parent 2f01c20 commit 978053a

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

tests/test_all.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,59 @@ def test_dir(library, module):
249249
assert not fails, "Missing exports: %s" % fails
250250

251251

252+
@pytest.mark.parametrize("module", list(NAMES))
253+
@pytest.mark.parametrize("library", wrapped_libraries)
254+
def test_compat_doesnt_hide_names(library, module):
255+
"""The base namespace can have more names than the ones explicitly exported
256+
by array-api-compat. Test that we're not suppressing them.
257+
"""
258+
bare_xp = pytest.importorskip(library)
259+
compat_xp = pytest.importorskip(f"array_api_compat.{library}")
260+
bare_mod = getattr(bare_xp, module) if module else bare_xp
261+
compat_mod = getattr(compat_xp, module) if module else compat_xp
262+
aapi_names = set(NAMES[module])
263+
extra_names = {
264+
name
265+
for name in dir(bare_mod)
266+
if not name.startswith("_") and name not in aapi_names
267+
}
268+
missing = extra_names - set(dir(compat_mod))
269+
270+
# These are spurious to begin with in the bare libraries
271+
missing -= {"annotations", "importlib", "warnings", "operator", "sys", "Sequence"}
272+
if module != "":
273+
missing -= {"Array", "test"}
274+
275+
assert not missing, "Non-Array API names have been hidden: %s" % missing
276+
277+
278+
@pytest.mark.parametrize("module", list(NAMES))
279+
@pytest.mark.parametrize("library", wrapped_libraries)
280+
def test_compat_spurious_names(library, module):
281+
"""Test that array-api-compat isn't adding non-Array API names
282+
to the namespace.
283+
"""
284+
bare_xp = pytest.importorskip(library)
285+
compat_xp = pytest.importorskip(f"array_api_compat.{library}")
286+
bare_mod = getattr(bare_xp, module) if module else bare_xp
287+
compat_mod = getattr(compat_xp, module) if module else compat_xp
288+
aapi_names = set(NAMES[module])
289+
compat_spurious_names = (
290+
set(dir(compat_mod))
291+
- set(dir(bare_mod))
292+
- aapi_names
293+
- {"__all__"}
294+
)
295+
# Quietly ignore *Result dataclasses
296+
compat_spurious_names = {
297+
name for name in compat_spurious_names if not name.endswith("Result")
298+
}
299+
300+
assert not compat_spurious_names, (
301+
"array-api-compat is adding non-Array API names: %s" % compat_spurious_names
302+
)
303+
304+
252305
@pytest.mark.parametrize(
253306
"name", [name for name in NAMES[""] if hasattr(builtins, name)]
254307
)

0 commit comments

Comments
 (0)