Skip to content

Commit 724ebd5

Browse files
committed
Merge branch 'test_spurious_names' into test_all
2 parents e57b9ef + 978053a commit 724ebd5

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

tests/test_all.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,59 @@ def test_all(library, module):
274274
assert not fails, "Missing exports: %s" % fails
275275

276276

277+
@pytest.mark.parametrize("module", list(NAMES))
278+
@pytest.mark.parametrize("library", wrapped_libraries)
279+
def test_compat_doesnt_hide_names(library, module):
280+
"""The base namespace can have more names than the ones explicitly exported
281+
by array-api-compat. Test that we're not suppressing them.
282+
"""
283+
bare_xp = pytest.importorskip(library)
284+
compat_xp = pytest.importorskip(f"array_api_compat.{library}")
285+
bare_mod = getattr(bare_xp, module) if module else bare_xp
286+
compat_mod = getattr(compat_xp, module) if module else compat_xp
287+
aapi_names = set(NAMES[module])
288+
extra_names = {
289+
name
290+
for name in dir(bare_mod)
291+
if not name.startswith("_") and name not in aapi_names
292+
}
293+
missing = extra_names - set(dir(compat_mod))
294+
295+
# These are spurious to begin with in the bare libraries
296+
missing -= {"annotations", "importlib", "warnings", "operator", "sys", "Sequence"}
297+
if module != "":
298+
missing -= {"Array", "test"}
299+
300+
assert not missing, "Non-Array API names have been hidden: %s" % missing
301+
302+
303+
@pytest.mark.parametrize("module", list(NAMES))
304+
@pytest.mark.parametrize("library", wrapped_libraries)
305+
def test_compat_spurious_names(library, module):
306+
"""Test that array-api-compat isn't adding non-Array API names
307+
to the namespace.
308+
"""
309+
bare_xp = pytest.importorskip(library)
310+
compat_xp = pytest.importorskip(f"array_api_compat.{library}")
311+
bare_mod = getattr(bare_xp, module) if module else bare_xp
312+
compat_mod = getattr(compat_xp, module) if module else compat_xp
313+
aapi_names = set(NAMES[module])
314+
compat_spurious_names = (
315+
set(dir(compat_mod))
316+
- set(dir(bare_mod))
317+
- aapi_names
318+
- {"__all__"}
319+
)
320+
# Quietly ignore *Result dataclasses
321+
compat_spurious_names = {
322+
name for name in compat_spurious_names if not name.endswith("Result")
323+
}
324+
325+
assert not compat_spurious_names, (
326+
"array-api-compat is adding non-Array API names: %s" % compat_spurious_names
327+
)
328+
329+
277330
@pytest.mark.parametrize(
278331
"name", [name for name in NAMES[""] if hasattr(builtins, name)]
279332
)

0 commit comments

Comments
 (0)