@@ -1864,14 +1864,13 @@ def test_take_along_axis_uint64_indices():
18641864 get_queue_or_skip ()
18651865
18661866 inds = dpt .arange (1 , 10 , 2 , dtype = "u8" )
1867-
18681867 x = dpt .tile (dpt .asarray ([0 , - 1 ], dtype = "i4" ), 5 )
18691868 res = dpt .take_along_axis (x , inds )
18701869 assert dpt .all (res == - 1 )
18711870
1872- x = dpt . tile ( dpt . asarray ([ 0 , - 1 ], dtype = "i4" ), ( 2 , 5 ))
1873- inds = dpt .arange ( 1 , 10 , 2 , dtype = "u8" )
1874- inds = dpt .broadcast_to (inds , (2 , 5 ) )
1871+ sh0 = 2
1872+ inds = dpt .broadcast_to ( inds , ( sh0 ,) + inds . shape )
1873+ x = dpt .broadcast_to (x , (sh0 ,) + x . shape )
18751874 res = dpt .take_along_axis (x , inds , axis = 1 )
18761875 assert dpt .all (res == - 1 )
18771876
@@ -1880,14 +1879,14 @@ def test_put_along_axis_uint64_indices():
18801879 get_queue_or_skip ()
18811880
18821881 inds = dpt .arange (1 , 10 , 2 , dtype = "u8" )
1883-
18841882 x = dpt .zeros (10 , dtype = "i4" )
18851883 dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ))
18861884 expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), 5 )
18871885 assert dpt .all (x == expected )
18881886
1889- x = dpt .zeros ((2 , 10 ), dtype = "i4" )
1890- inds = dpt .broadcast_to (inds , (2 , 5 ))
1887+ sh0 = 2
1888+ inds = dpt .broadcast_to (inds , (sh0 ,) + inds .shape )
1889+ x = dpt .zeros ((sh0 ,) + x .shape , dtype = "i4" )
18911890 dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
18921891 expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
18931892 assert dpt .all (expected == x )
0 commit comments