Skip to content

Commit ecbefdb

Browse files
committed
Fix array tests
1 parent 0628e1f commit ecbefdb

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

src/pydvl/utils/array.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -949,13 +949,9 @@ def array_nonzero(
949949
"""
950950
if is_tensor(x):
951951
assert torch is not None
952-
tensor_array = cast(Tensor, x)
953-
# torch.nonzero returns a tensor of indices
954-
indices = torch.nonzero(tensor_array, as_tuple=True)
955-
return cast(tuple[NDArray, ...], tuple(t.cpu().numpy() for t in indices))
956-
else: # Fallback to numpy approach
957-
numpy_array = to_numpy(x)
958-
return cast(tuple[NDArray, ...], np.nonzero(numpy_array))
952+
nz = torch.nonzero(cast(Tensor, x), as_tuple=True)
953+
return cast(tuple[NDArray, ...], tuple(t.cpu().numpy() for t in nz))
954+
return cast(tuple[NDArray, ...], np.nonzero(to_numpy(x)))
959955

960956

961957
def is_categorical(x: Array[Any]) -> bool:

tests/utils/test_array.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def test_array_unique_torch():
260260
result, indices = array_unique(arr, return_index=True)
261261
assert is_tensor(result)
262262
assert torch.equal(result, torch.tensor([1, 2, 3]))
263-
assert torch.equal(indices, torch.tensor([0, 1, 3]))
263+
np.testing.assert_array_equal(indices, np.array([0, 1, 3]))
264264

265265

266266
def test_array_concatenate():
@@ -467,18 +467,18 @@ def test_array_nonzero_torch():
467467
result = array_nonzero(array)
468468
assert isinstance(result, tuple)
469469
assert len(result) == 1
470-
assert is_tensor(result[0])
471-
assert torch.equal(result[0], torch.tensor([1, 3]))
470+
assert is_numpy(result[0])
471+
np.testing.assert_array_equal(result[0], np.array([1, 3]))
472472

473473
# 2D tensor
474474
array_2d = torch.tensor([[0, 1], [2, 0]])
475475
result = array_nonzero(array_2d)
476476
assert isinstance(result, tuple)
477477
assert len(result) == 2
478-
assert is_tensor(result[0])
479-
assert is_tensor(result[1])
480-
assert torch.equal(result[0], torch.tensor([0, 1]))
481-
assert torch.equal(result[1], torch.tensor([1, 0]))
478+
assert is_numpy(result[0])
479+
assert is_numpy(result[1])
480+
np.testing.assert_array_equal(result[0], np.array([0, 1]))
481+
np.testing.assert_array_equal(result[1], np.array([1, 0]))
482482

483483

484484
def test_stratified_split_indices():

0 commit comments

Comments
 (0)