@@ -1378,6 +1378,20 @@ def test_basic(self, dtype):
13781378 expected = numpy .trim_zeros (a )
13791379 assert_array_equal (result , expected )
13801380
1381+ @testing .with_requires ("numpy>=2.2" )
1382+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
1383+ @pytest .mark .parametrize ("trim" , ["F" , "B" , "fb" ])
1384+ @pytest .mark .parametrize ("ndim" , [0 , 1 , 2 , 3 ])
1385+ def test_basic_nd (self , dtype , trim , ndim ):
1386+ a = numpy .ones ((2 ,) * ndim , dtype = dtype )
1387+ a = numpy .pad (a , (2 , 1 ), mode = "constant" , constant_values = 0 )
1388+ ia = dpnp .array (a )
1389+
1390+ for axis in list (range (ndim )) + [None ]:
1391+ result = dpnp .trim_zeros (ia , trim = trim , axis = axis )
1392+ expected = numpy .trim_zeros (a , trim = trim , axis = axis )
1393+ assert_array_equal (result , expected )
1394+
13811395 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
13821396 @pytest .mark .parametrize ("trim" , ["F" , "B" ])
13831397 def test_trim (self , dtype , trim ):
@@ -1398,6 +1412,19 @@ def test_all_zero(self, dtype, trim):
13981412 expected = numpy .trim_zeros (a , trim )
13991413 assert_array_equal (result , expected )
14001414
1415+ @testing .with_requires ("numpy>=2.2" )
1416+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
1417+ @pytest .mark .parametrize ("trim" , ["F" , "B" , "fb" ])
1418+ @pytest .mark .parametrize ("ndim" , [0 , 1 , 2 , 3 ])
1419+ def test_all_zero_nd (self , dtype , trim , ndim ):
1420+ a = numpy .zeros ((3 ,) * ndim , dtype = dtype )
1421+ ia = dpnp .array (a )
1422+
1423+ for axis in list (range (ndim )) + [None ]:
1424+ result = dpnp .trim_zeros (ia , trim = trim , axis = axis )
1425+ expected = numpy .trim_zeros (a , trim = trim , axis = axis )
1426+ assert_array_equal (result , expected )
1427+
14011428 def test_size_zero (self ):
14021429 a = numpy .zeros (0 )
14031430 ia = dpnp .array (a )
@@ -1416,17 +1443,11 @@ def test_overflow(self, a):
14161443 expected = numpy .trim_zeros (a )
14171444 assert_array_equal (result , expected )
14181445
1419- # TODO: modify once SAT-7616
1420- # numpy 2.2 validates trim rules
1421- @testing .with_requires ("numpy<2.2" )
1422- def test_trim_no_rule (self ):
1423- a = numpy .array ([0 , 0 , 1 , 0 , 2 , 3 , 4 , 0 ])
1424- ia = dpnp .array (a )
1425- trim = "ADE" # no "F" or "B" in trim string
1426-
1427- result = dpnp .trim_zeros (ia , trim )
1428- expected = numpy .trim_zeros (a , trim )
1429- assert_array_equal (result , expected )
1446+ @testing .with_requires ("numpy>=2.2" )
1447+ @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
1448+ def test_trim_no_fb_in_rule (self , xp ):
1449+ a = xp .array ([0 , 0 , 1 , 0 , 2 , 3 , 4 , 0 ])
1450+ assert_raises (ValueError , xp .trim_zeros , a , "ADE" )
14301451
14311452 def test_list_array (self ):
14321453 assert_raises (TypeError , dpnp .trim_zeros , [0 , 0 , 1 , 0 , 2 , 3 , 4 , 0 ])
0 commit comments