Skip to content

Commit 9322549

Browse files
committed
Update and add more test scenaio
1 parent 5883cb4 commit 9322549

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

dpnp/tests/test_manipulation.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,20 @@ def test_basic(self, dtype):
13781378
expected = numpy.trim_zeros(a)
13791379
assert_array_equal(result, expected)
13801380

1381+
@testing.with_requires("numpy>=2.2")
1382+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1383+
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
1384+
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
1385+
def test_basic_nd(self, dtype, trim, ndim):
1386+
a = numpy.ones((2,) * ndim, dtype=dtype)
1387+
a = numpy.pad(a, (2, 1), mode="constant", constant_values=0)
1388+
ia = dpnp.array(a)
1389+
1390+
for axis in list(range(ndim)) + [None]:
1391+
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
1392+
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
1393+
assert_array_equal(result, expected)
1394+
13811395
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
13821396
@pytest.mark.parametrize("trim", ["F", "B"])
13831397
def test_trim(self, dtype, trim):
@@ -1398,6 +1412,19 @@ def test_all_zero(self, dtype, trim):
13981412
expected = numpy.trim_zeros(a, trim)
13991413
assert_array_equal(result, expected)
14001414

1415+
@testing.with_requires("numpy>=2.2")
1416+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1417+
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
1418+
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
1419+
def test_all_zero_nd(self, dtype, trim, ndim):
1420+
a = numpy.zeros((3,) * ndim, dtype=dtype)
1421+
ia = dpnp.array(a)
1422+
1423+
for axis in list(range(ndim)) + [None]:
1424+
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
1425+
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
1426+
assert_array_equal(result, expected)
1427+
14011428
def test_size_zero(self):
14021429
a = numpy.zeros(0)
14031430
ia = dpnp.array(a)
@@ -1416,17 +1443,11 @@ def test_overflow(self, a):
14161443
expected = numpy.trim_zeros(a)
14171444
assert_array_equal(result, expected)
14181445

1419-
# TODO: modify once SAT-7616
1420-
# numpy 2.2 validates trim rules
1421-
@testing.with_requires("numpy<2.2")
1422-
def test_trim_no_rule(self):
1423-
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0])
1424-
ia = dpnp.array(a)
1425-
trim = "ADE" # no "F" or "B" in trim string
1426-
1427-
result = dpnp.trim_zeros(ia, trim)
1428-
expected = numpy.trim_zeros(a, trim)
1429-
assert_array_equal(result, expected)
1446+
@testing.with_requires("numpy>=2.2")
1447+
@pytest.mark.parametrize("xp", [numpy, dpnp])
1448+
def test_trim_no_fb_in_rule(self, xp):
1449+
a = xp.array([0, 0, 1, 0, 2, 3, 4, 0])
1450+
assert_raises(ValueError, xp.trim_zeros, a, "ADE")
14301451

14311452
def test_list_array(self):
14321453
assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0])

0 commit comments

Comments
 (0)