@@ -879,21 +879,6 @@ def test_mode_clip(self):
879879 assert (result == dpnp .array ([- 2 , 0 , - 2 , 2 ])).all ()
880880
881881
882- def test_choose ():
883- a = numpy .r_ [:4 ]
884- ia = dpnp .array (a )
885- b = numpy .r_ [- 4 :0 ]
886- ib = dpnp .array (b )
887- c = numpy .r_ [100 :500 :100 ]
888- ic = dpnp .array (c )
889-
890- inds_np = numpy .zeros (4 , dtype = "i4" )
891- inds = dpnp .zeros (4 , dtype = "i4" )
892- expected = numpy .choose (inds_np , [a , b , c ])
893- result = dpnp .choose (inds , [ia , ib , ic ])
894- assert_array_equal (expected , result )
895-
896-
897882@pytest .mark .parametrize ("val" , [- 1 , 0 , 1 ], ids = ["-1" , "0" , "1" ])
898883@pytest .mark .parametrize (
899884 "array" ,
@@ -1448,3 +1433,76 @@ def test_compress_strided(self):
14481433 result = dpnp .compress (cond , a )
14491434 expected = numpy .compress (cond_np , a_np )
14501435 assert_array_equal (result , expected )
1436+
1437+
1438+ class TestChoose :
1439+ def test_choose_basic (self ):
1440+ indices = [0 , 1 , 0 ]
1441+ # use a single array for choices
1442+ chcs_np = numpy .arange (2 * len (indices ))
1443+ chcs = dpnp .arange (2 * len (indices ))
1444+ inds_np = numpy .array (indices )
1445+ inds = dpnp .array (indices )
1446+ expected = numpy .choose (inds_np , chcs_np )
1447+ result = dpnp .choose (inds , chcs )
1448+ assert_array_equal (expected , result )
1449+
1450+ def test_choose_method_basic (self ):
1451+ indices = [0 , 1 , 2 ]
1452+ # use a single array for choices
1453+ chcs_np = numpy .arange (3 * len (indices ))
1454+ chcs = dpnp .arange (3 * len (indices ))
1455+ inds_np = numpy .array (indices )
1456+ inds = dpnp .array (indices )
1457+ expected = inds_np .choose (chcs_np )
1458+ result = inds .choose (chcs )
1459+ assert_array_equal (expected , result )
1460+
1461+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
1462+ def test_choose_inds_all_dtypes (self , dtype ):
1463+ if not dpnp .issubdtype (dtype , dpnp .integer ) and dtype != dpnp .bool :
1464+ inds = dpnp .zeros (1 , dtype = dtype )
1465+ chcs = dpnp .ones (1 , dtype = dtype )
1466+ with pytest .raises (TypeError ):
1467+ dpnp .choose (inds , chcs )
1468+ else :
1469+ inds_np = numpy .array ([1 , 0 , 1 ], dtype = dtype )
1470+ inds = dpnp .array (inds_np )
1471+ chcs_np = numpy .array ([1 , 2 , 3 ], dtype = dtype )
1472+ chcs = dpnp .array (chcs_np )
1473+ expected = numpy .choose (inds_np , chcs_np )
1474+ result = dpnp .choose (inds , chcs )
1475+ assert_array_equal (expected , result )
1476+
1477+ def test_choose_invalid_out_errors (self ):
1478+ q1 = dpctl .SyclQueue ()
1479+ q2 = dpctl .SyclQueue ()
1480+ chcs = dpnp .ones (10 , dtype = "i4" , sycl_queue = q1 )
1481+ inds = dpnp .zeros (10 , dtype = "i4" , sycl_queue = q1 )
1482+ out_bad_shape = dpnp .empty (11 , dtype = chcs .dtype , sycl_queue = q1 )
1483+ with pytest .raises (ValueError ):
1484+ dpnp .choose (inds , [chcs ], out = out_bad_shape )
1485+ out_bad_queue = dpnp .empty (chcs .shape , dtype = chcs .dtype , sycl_queue = q2 )
1486+ with pytest .raises (ExecutionPlacementError ):
1487+ dpnp .choose (inds , [chcs ], out = out_bad_queue )
1488+ out_bad_dt = dpnp .empty (chcs .shape , dtype = "i8" , sycl_queue = q1 )
1489+ with pytest .raises (TypeError ):
1490+ dpnp .choose (inds , [chcs ], out = out_bad_dt )
1491+ out_read_only = dpnp .empty (chcs .shape , dtype = chcs .dtype , sycl_queue = q1 )
1492+ out_read_only .flags .writable = False
1493+ with pytest .raises (ValueError ):
1494+ dpnp .choose (inds , [chcs ], out = out_read_only )
1495+
1496+ def test_choose_empty (self ):
1497+ sh = (10 , 0 , 5 )
1498+ inds = dpnp .ones (sh , dtype = "i4" )
1499+ chcs = dpnp .ones (sh )
1500+ r = dpnp .choose (inds , chcs )
1501+ assert r .shape == sh
1502+ r = dpnp .choose (inds , (chcs ,) * 2 )
1503+ assert r .shape == sh
1504+ inds = dpnp .unstack (inds )[0 ]
1505+ r = dpnp .choose (inds , chcs )
1506+ assert r .shape == sh [1 :]
1507+ r = dpnp .choose (inds , [chcs ])
1508+ assert r .shape == sh
0 commit comments