@@ -1339,18 +1339,24 @@ def test_error(self):
13391339
13401340class TestCompress :
13411341 def test_compress_basic (self ):
1342+ conditions = [True , False , True ]
1343+ a_np = numpy .arange (16 ).reshape (4 , 4 )
13421344 a = dpnp .arange (16 ).reshape (4 , 4 )
1343- condition = dpnp .asarray ([True , False , True ])
1344- r = dpnp .compress (condition , a , axis = 0 )
1345- assert_array_equal (r [0 ], a [0 ])
1346- assert_array_equal (r [1 ], a [2 ])
1345+ cond_np = numpy .array (conditions )
1346+ cond = dpnp .array (conditions )
1347+ expected = numpy .compress (cond_np , a_np , axis = 0 )
1348+ result = dpnp .compress (cond , a , axis = 0 )
1349+ assert_array_equal (expected , result )
13471350
13481351 @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
13491352 def test_compress_condition_all_dtypes (self , dtype ):
1353+ a_np = numpy .arange (10 , dtype = "i4" )
13501354 a = dpnp .arange (10 , dtype = "i4" )
1351- condition = dpnp .tile (dpnp .asarray ([0 , 1 ], dtype = dtype ), 5 )
1352- r = dpnp .compress (condition , a )
1353- assert_array_equal (r , a [1 ::2 ])
1355+ cond_np = numpy .tile (numpy .asarray ([0 , 1 ], dtype = dtype ), 5 )
1356+ cond = dpnp .tile (dpnp .asarray ([0 , 1 ], dtype = dtype ), 5 )
1357+ expected = numpy .compress (cond_np , a_np )
1358+ result = dpnp .compress (cond , a )
1359+ assert_array_equal (expected , result )
13541360
13551361 def test_compress_invalid_out_errors (self ):
13561362 q1 = dpctl .SyclQueue ()
0 commit comments