@@ -274,6 +274,59 @@ def test_all(library, module):
274
274
assert not fails , "Missing exports: %s" % fails
275
275
276
276
277
+ @pytest .mark .parametrize ("module" , list (NAMES ))
278
+ @pytest .mark .parametrize ("library" , wrapped_libraries )
279
+ def test_compat_doesnt_hide_names (library , module ):
280
+ """The base namespace can have more names than the ones explicitly exported
281
+ by array-api-compat. Test that we're not suppressing them.
282
+ """
283
+ bare_xp = pytest .importorskip (library )
284
+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
285
+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
286
+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
287
+ aapi_names = set (NAMES [module ])
288
+ extra_names = {
289
+ name
290
+ for name in dir (bare_mod )
291
+ if not name .startswith ("_" ) and name not in aapi_names
292
+ }
293
+ missing = extra_names - set (dir (compat_mod ))
294
+
295
+ # These are spurious to begin with in the bare libraries
296
+ missing -= {"annotations" , "importlib" , "warnings" , "operator" , "sys" , "Sequence" }
297
+ if module != "" :
298
+ missing -= {"Array" , "test" }
299
+
300
+ assert not missing , "Non-Array API names have been hidden: %s" % missing
301
+
302
+
303
+ @pytest .mark .parametrize ("module" , list (NAMES ))
304
+ @pytest .mark .parametrize ("library" , wrapped_libraries )
305
+ def test_compat_spurious_names (library , module ):
306
+ """Test that array-api-compat isn't adding non-Array API names
307
+ to the namespace.
308
+ """
309
+ bare_xp = pytest .importorskip (library )
310
+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
311
+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
312
+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
313
+ aapi_names = set (NAMES [module ])
314
+ compat_spurious_names = (
315
+ set (dir (compat_mod ))
316
+ - set (dir (bare_mod ))
317
+ - aapi_names
318
+ - {"__all__" }
319
+ )
320
+ # Quietly ignore *Result dataclasses
321
+ compat_spurious_names = {
322
+ name for name in compat_spurious_names if not name .endswith ("Result" )
323
+ }
324
+
325
+ assert not compat_spurious_names , (
326
+ "array-api-compat is adding non-Array API names: %s" % compat_spurious_names
327
+ )
328
+
329
+
277
330
@pytest .mark .parametrize (
278
331
"name" , [name for name in NAMES ["" ] if hasattr (builtins , name )]
279
332
)
0 commit comments