@@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
212212 assert meta_namespace (* args , xp = xp ) in (xp , np_compat )
213213
214214
215- def test_capabilities (xp : ModuleType ):
216- expect = {"boolean indexing" , "data-dependent shapes" }
217- if xp .__array_api_version__ >= "2024.12" :
218- expect .add ("max dimensions" )
219- assert capabilities (xp ).keys () == expect
215+ class TestCapabilities :
216+ def test_basic (self , xp : ModuleType ):
217+ expect = {"boolean indexing" , "data-dependent shapes" }
218+ if xp .__array_api_version__ >= "2024.12" :
219+ expect .add ("max dimensions" )
220+ assert capabilities (xp ).keys () == expect
221+
222+ def test_device (self , xp : ModuleType , library : Backend , device : Device ):
223+ expect_keys = {"boolean indexing" , "data-dependent shapes" }
224+ if xp .__array_api_version__ >= "2024.12" :
225+ expect_keys .add ("max dimensions" )
226+ assert capabilities (xp , device = device ).keys () == expect_keys
227+
228+ if library .like (Backend .TORCH ):
229+ # The output of capabilities is device-specific.
230+
231+ # Test that device=None gets the current default device.
232+ expect = capabilities (xp , device = device )
233+ with xp .device (device ):
234+ actual = capabilities (xp )
235+ assert actual == expect
236+
237+ # Test that we're accepting anything that is accepted by the
238+ # device= parameter in other functions
239+ actual = capabilities (xp , device = device .type ) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
220240
221241
222242class Wrapper (Generic [T ]):
0 commit comments