@@ -85,6 +85,11 @@ def test_5D(self):
8585 y = atleast_nd (x , ndim = 9 , xp = xp )
8686 assert_array_equal (y , xp .ones ((1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 )))
8787
88+ def test_device (self ):
89+ device = xp .Device ("device1" )
90+ x = xp .asarray ([1 , 2 , 3 ], device = device )
91+ assert atleast_nd (x , ndim = 2 , xp = xp ).device == device
92+
8893
8994class TestCov :
9095 def test_basic (self ):
@@ -120,6 +125,11 @@ def test_combination(self):
120125 assert_allclose (cov (x , xp = xp ), xp .asarray (11.71 ))
121126 assert_allclose (cov (y , xp = xp ), xp .asarray (2.144133 ), rtol = 1e-6 )
122127
128+ def test_device (self ):
129+ device = xp .Device ("device1" )
130+ x = xp .asarray ([1 , 2 , 3 ], device = device )
131+ assert cov (x , xp = xp ).device == device
132+
123133
124134class TestCreateDiagonal :
125135 def test_1d (self ):
@@ -156,6 +166,11 @@ def test_2d(self):
156166 with pytest .raises (ValueError , match = "1-dimensional" ):
157167 create_diagonal (xp .asarray ([[1 ]]), xp = xp )
158168
169+ def test_device (self ):
170+ device = xp .Device ("device1" )
171+ x = xp .asarray ([1 , 2 , 3 ], device = device )
172+ assert create_diagonal (x , xp = xp ).device == device
173+
159174
160175class TestExpandDims :
161176 def test_functionality (self ):
@@ -205,6 +220,11 @@ def test_positive_negative_repeated(self):
205220 with pytest .raises (ValueError , match = "Duplicate dimensions" ):
206221 expand_dims (a , axis = (3 , - 3 ), xp = xp )
207222
223+ def test_device (self ):
224+ device = xp .Device ("device1" )
225+ x = xp .asarray ([1 , 2 , 3 ], device = device )
226+ assert expand_dims (x , axis = 0 , xp = xp ).device == device
227+
208228
209229class TestKron :
210230 def test_basic (self ):
@@ -270,6 +290,12 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
270290 k = kron (a , b , xp = xp )
271291 assert_equal (k .shape , expected_shape , err_msg = "Unexpected shape from kron" )
272292
293+ def test_device (self ):
294+ device = xp .Device ("device1" )
295+ x1 = xp .asarray ([1 , 2 , 3 ], device = device )
296+ x2 = xp .asarray ([4 , 5 ], device = device )
297+ assert kron (x1 , x2 , xp = xp ).device == device
298+
273299
274300class TestSetDiff1D :
275301 def test_setdiff1d (self ):
@@ -298,6 +324,12 @@ def test_assume_unique(self):
298324 actual = setdiff1d (x1 , x2 , assume_unique = True , xp = xp )
299325 assert_array_equal (actual , expected )
300326
327+ def test_device (self ):
328+ device = xp .Device ("device1" )
329+ x1 = xp .asarray ([3 , 8 , 20 ], device = device )
330+ x2 = xp .asarray ([2 , 3 , 4 ], device = device )
331+ assert setdiff1d (x1 , x2 , xp = xp ).device == device
332+
301333
302334class TestSinc :
303335 def test_simple (self ):
@@ -316,3 +348,8 @@ def test_3d(self):
316348 expected = xp .zeros ((3 , 3 , 2 ))
317349 expected [0 , 0 , 0 ] = 1.0
318350 assert_allclose (sinc (x , xp = xp ), expected , atol = 1e-15 )
351+
352+ def test_device (self ):
353+ device = xp .Device ("device1" )
354+ x = xp .asarray (0.0 , device = device )
355+ assert sinc (x , xp = xp ).device == device
0 commit comments