@@ -249,6 +249,59 @@ def test_dir(library, module):
249
249
assert not fails , "Missing exports: %s" % fails
250
250
251
251
252
+ @pytest .mark .parametrize ("module" , list (NAMES ))
253
+ @pytest .mark .parametrize ("library" , wrapped_libraries )
254
+ def test_compat_doesnt_hide_names (library , module ):
255
+ """The base namespace can have more names than the ones explicitly exported
256
+ by array-api-compat. Test that we're not suppressing them.
257
+ """
258
+ bare_xp = pytest .importorskip (library )
259
+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
260
+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
261
+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
262
+ aapi_names = set (NAMES [module ])
263
+ extra_names = {
264
+ name
265
+ for name in dir (bare_mod )
266
+ if not name .startswith ("_" ) and name not in aapi_names
267
+ }
268
+ missing = extra_names - set (dir (compat_mod ))
269
+
270
+ # These are spurious to begin with in the bare libraries
271
+ missing -= {"annotations" , "importlib" , "warnings" , "operator" , "sys" , "Sequence" }
272
+ if module != "" :
273
+ missing -= {"Array" , "test" }
274
+
275
+ assert not missing , "Non-Array API names have been hidden: %s" % missing
276
+
277
+
278
+ @pytest .mark .parametrize ("module" , list (NAMES ))
279
+ @pytest .mark .parametrize ("library" , wrapped_libraries )
280
+ def test_compat_spurious_names (library , module ):
281
+ """Test that array-api-compat isn't adding non-Array API names
282
+ to the namespace.
283
+ """
284
+ bare_xp = pytest .importorskip (library )
285
+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
286
+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
287
+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
288
+ aapi_names = set (NAMES [module ])
289
+ compat_spurious_names = (
290
+ set (dir (compat_mod ))
291
+ - set (dir (bare_mod ))
292
+ - aapi_names
293
+ - {"__all__" }
294
+ )
295
+ # Quietly ignore *Result dataclasses
296
+ compat_spurious_names = {
297
+ name for name in compat_spurious_names if not name .endswith ("Result" )
298
+ }
299
+
300
+ assert not compat_spurious_names , (
301
+ "array-api-compat is adding non-Array API names: %s" % compat_spurious_names
302
+ )
303
+
304
+
252
305
@pytest .mark .parametrize (
253
306
"name" , [name for name in NAMES ["" ] if hasattr (builtins , name )]
254
307
)
0 commit comments