Skip to content

Commit 9ffc9b3

Browse files
committed
try revert
1 parent c77aca1 commit 9ffc9b3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

sklearn/utils/tests/test_array_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,10 @@ def test_move_to_array_api_conversions(array_input, reference):
156156

157157
with config_context(array_api_dispatch=True):
158158
array_in = xp_from.asarray([1, 2, 3], device=array_input.device)
159-
array_out = move_to(array_in, xp=xp_to, device=reference.device)
159+
device_reference = device(xp_to.asarray([1], device=reference.device))
160+
array_out = move_to(array_in, xp=xp_to, device=device_reference)
160161
assert get_namespace(array_out)[0] == xp_to
161-
assert device(array_out) == device(xp_to.asarray([1], device=reference.device))
162+
assert device(array_out) == device_reference
162163

163164

164165
def test_move_to_sparse():
@@ -167,7 +168,6 @@ def test_move_to_sparse():
167168
xp_torch = _array_api_for_tests("torch", "cpu")
168169

169170
sparse1 = sp.csr_array([0, 1, 2, 3])
170-
sparse2 = sp.csr_array([0, 1, 0, 1])
171171
numpy_array = numpy.array([1, 2, 3])
172172

173173
with config_context(array_api_dispatch=True):

0 commit comments

Comments
 (0)