Skip to content

Commit fadf701

Browse files
committed
TST: setdiff1d: add tests
1 parent f835502 commit fadf701

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

src/array_api_extra/_lib/_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def in1d(
5252
order = xp.argsort(ar, stable=True)
5353
reverse_order = xp.argsort(order, stable=True)
5454
sar = xp.take(ar, order, axis=0)
55-
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
55+
if sar.size >= 1:
56+
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
57+
else:
58+
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
5659
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
5760
ret = xp.take(flag, reverse_order, axis=0)
5861

tests/test_funcs.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
import pytest
1111
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
1212

13-
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
13+
from array_api_extra import (
14+
atleast_nd,
15+
cov,
16+
create_diagonal,
17+
expand_dims,
18+
kron,
19+
setdiff1d,
20+
sinc,
21+
)
1422

1523
if typing.TYPE_CHECKING:
1624
from array_api_extra._lib._typing import Array
@@ -263,6 +271,34 @@ def test_positive_negative_repeated(self):
263271
expand_dims(a, axis=(3, -3), xp=xp)
264272

265273

274+
class TestSetDiff1D:
275+
def test_setdiff1d(self):
276+
x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])
277+
x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5])
278+
279+
expected = xp.asarray([6, 7])
280+
actual = setdiff1d(x1, x2, xp=xp)
281+
assert_array_equal(actual, expected)
282+
283+
x1 = xp.arange(21)
284+
x2 = xp.arange(19)
285+
expected = xp.asarray([19, 20])
286+
actual = setdiff1d(x1, x2, xp=xp)
287+
assert_array_equal(actual, expected)
288+
289+
assert_array_equal(setdiff1d(xp.empty(0), xp.empty(0), xp=xp), xp.empty(0))
290+
x1 = xp.empty(0, dtype=xp.uint32)
291+
x2 = x1
292+
assert_equal(setdiff1d(x1, x2, xp=xp).dtype, xp.uint32)
293+
294+
def test_setdiff1d_unique(self):
295+
x1 = xp.asarray([3, 2, 1])
296+
x2 = xp.asarray([7, 5, 2])
297+
expected = xp.asarray([3, 1])
298+
actual = setdiff1d(x1, x2, assume_unique=True, xp=xp)
299+
assert_array_equal(actual, expected)
300+
301+
266302
class TestSinc:
267303
def test_simple(self):
268304
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))

0 commit comments

Comments
 (0)