@@ -652,16 +652,9 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
652652 ia , iind = dpnp .array (a ), dpnp .array (ind )
653653
654654 for axis in range (ndim ):
655- if ndim != 1 and numpy .issubdtype (idx_dt , numpy .uint64 ):
656- # For this special case, dpnp raises an error but NumPy works
657- # TODO: remove the workaround when dpctl-1936 is fixed
658- assert_raises (
659- ValueError , dpnp .put_along_axis , ia , iind , values , axis
660- )
661- else :
662- numpy .put_along_axis (a , ind , values , axis )
663- dpnp .put_along_axis (ia , iind , values , axis )
664- assert_array_equal (ia , a )
655+ numpy .put_along_axis (a , ind , values , axis )
656+ dpnp .put_along_axis (ia , iind , values , axis )
657+ assert_array_equal (ia , a )
665658
666659 @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
667660 @pytest .mark .parametrize ("dt" , [bool , numpy .float32 ])
@@ -678,10 +671,6 @@ def test_broadcast(self, arr_dt, idx_dt):
678671 ia , iind = dpnp .array (a ), dpnp .array (ind )
679672
680673 if numpy .issubdtype (idx_dt , numpy .uint64 ):
681- # For this special case, dpnp raises an error but NumPy works
682- # TODO: remove the workaround when dpctl-1936 is fixed
683- assert_raises (ValueError , dpnp .put_along_axis , ia , iind , 20 , axis = 1 )
684- else :
685674 numpy .put_along_axis (a , ind , 20 , axis = 1 )
686675 dpnp .put_along_axis (ia , iind , 20 , axis = 1 )
687676 assert_array_equal (ia , a )
@@ -732,10 +721,11 @@ def test_1d(self, a_dt, ind_dt, indices, mode):
732721 expected = numpy .take (a , ind , mode = mode )
733722 assert_array_equal (result , expected )
734723 elif numpy .issubdtype (ind_dt , numpy .uint64 ):
735- # For this special case, although casting `ind_dt`` to numpy.intp
736- # is not safe, NumPy and dpnp do not raise an error
724+ # For this special case, although casting `ind_dt` to numpy.intp
725+ # is not safe, dpnp do not raise an error
726+ # NumPy only raises an error on Windows
737727 result = dpnp .take (ia , iind , mode = mode )
738- expected = numpy .take (a , ind , mode = mode )
728+ expected = numpy .take (a , ind . astype ( numpy . int64 ) , mode = mode )
739729 assert_array_equal (result , expected )
740730 else :
741731 assert_raises (TypeError , ia .take , iind , mode = mode )
@@ -758,9 +748,15 @@ def test_2d(self, a_dt, ind_dt, indices, mode, axis):
758748 ind = numpy .array (indices , dtype = ind_dt )
759749 ia , iind = dpnp .array (a ), dpnp .array (ind )
760750
761- result = ia .take (iind , axis = axis , mode = mode )
762- expected = a .take (ind , axis = axis , mode = mode )
763- assert_array_equal (result , expected )
751+ if numpy .issubdtype (ind_dt , numpy .uint64 ):
752+ # For this special case, NumPy raises an error on Windows
753+ result = ia .take (iind , axis = axis , mode = mode )
754+ expected = a .take (ind .astype (numpy .int64 ), axis = axis , mode = mode )
755+ assert_array_equal (result , expected )
756+ else :
757+ result = ia .take (iind , axis = axis , mode = mode )
758+ expected = a .take (ind , axis = axis , mode = mode )
759+ assert_array_equal (result , expected )
764760
765761 @pytest .mark .parametrize ("a_dt" , get_all_dtypes (no_none = True ))
766762 @pytest .mark .parametrize ("mode" , ["clip" , "wrap" ])
@@ -852,14 +848,9 @@ def test_multi_dimensions(self, arr_dt, idx_dt, ndim):
852848 ia , iind = dpnp .array (a ), dpnp .array (ind )
853849
854850 for axis in range (ndim ):
855- if ndim != 1 and numpy .issubdtype (idx_dt , numpy .uint64 ):
856- # For this special case, dpnp raises an error but NumPy works
857- # TODO: remove the workaround when dpctl-1936 is fixed
858- assert_raises (ValueError , dpnp .take_along_axis , ia , iind , axis )
859- else :
860- result = dpnp .take_along_axis (ia , iind , axis )
861- expected = numpy .take_along_axis (a , ind , axis )
862- assert_array_equal (expected , result )
851+ result = dpnp .take_along_axis (ia , iind , axis )
852+ expected = numpy .take_along_axis (a , ind , axis )
853+ assert_array_equal (expected , result )
863854
864855 @pytest .mark .parametrize ("xp" , [numpy , dpnp ])
865856 def test_not_enough_indices (self , xp ):
@@ -892,14 +883,9 @@ def test_empty(self, a_dt, idx_dt):
892883 ind = numpy .ones ((3 , 0 , 5 ), dtype = idx_dt )
893884 ia , iind = dpnp .array (a ), dpnp .array (ind )
894885
895- if numpy .issubdtype (idx_dt , numpy .uint64 ):
896- # For this special case, dpnp raises an error but NumPy works
897- # TODO: remove the workaround when dpctl-1936 is fixed
898- assert_raises (ValueError , dpnp .take_along_axis , ia , iind , axis = 1 )
899- else :
900- result = dpnp .take_along_axis (ia , iind , axis = 1 )
901- expected = numpy .take_along_axis (a , ind , axis = 1 )
902- assert_array_equal (expected , result )
886+ result = dpnp .take_along_axis (ia , iind , axis = 1 )
887+ expected = numpy .take_along_axis (a , ind , axis = 1 )
888+ assert_array_equal (expected , result )
903889
904890 @pytest .mark .parametrize ("a_dt" , get_all_dtypes (no_none = True ))
905891 @pytest .mark .parametrize ("idx_dt" , get_integer_dtypes ())
@@ -908,14 +894,9 @@ def test_broadcast(self, a_dt, idx_dt):
908894 ind = numpy .ones ((1 , 2 , 5 ), dtype = idx_dt )
909895 ia , iind = dpnp .array (a ), dpnp .array (ind )
910896
911- if numpy .issubdtype (idx_dt , numpy .uint64 ):
912- # For this special case, dpnp raises an error but NumPy works
913- # TODO: remove the workaround when dpctl-1936 is fixed
914- assert_raises (ValueError , dpnp .take_along_axis , ia , iind , axis = 1 )
915- else :
916- result = dpnp .take_along_axis (ia , iind , axis = 1 )
917- expected = numpy .take_along_axis (a , ind , axis = 1 )
918- assert_array_equal (expected , result )
897+ result = dpnp .take_along_axis (ia , iind , axis = 1 )
898+ expected = numpy .take_along_axis (a , ind , axis = 1 )
899+ assert_array_equal (expected , result )
919900
920901 def test_mode_wrap (self ):
921902 a = numpy .array ([- 2 , - 1 , 0 , 1 , 2 ])
0 commit comments