@@ -837,11 +837,26 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
837837 < span class ="n "> _sort_ascending</ span > < span class ="p "> ,</ span >
838838 < span class ="n "> _sort_descending</ span > < span class ="p "> ,</ span >
839839< span class ="p "> )</ span >
840+ < span class ="kn "> from</ span > < span class ="nn "> ._tensor_sorting_radix_impl</ span > < span class ="kn "> import</ span > < span class ="p "> (</ span >
841+ < span class ="n "> _radix_argsort_ascending</ span > < span class ="p "> ,</ span >
842+ < span class ="n "> _radix_argsort_descending</ span > < span class ="p "> ,</ span >
843+ < span class ="n "> _radix_sort_ascending</ span > < span class ="p "> ,</ span >
844+ < span class ="n "> _radix_sort_descending</ span > < span class ="p "> ,</ span >
845+ < span class ="n "> _radix_sort_dtype_supported</ span > < span class ="p "> ,</ span >
846+ < span class ="p "> )</ span >
840847
841848< span class ="n "> __all__</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="s2 "> "sort"</ span > < span class ="p "> ,</ span > < span class ="s2 "> "argsort"</ span > < span class ="p "> ]</ span >
842849
843850
844- < div class ="viewcode-block " id ="sort "> < a class ="viewcode-back " href ="../../../api_reference/dpctl/generated/dpctl.tensor.sort.html#dpctl.tensor.sort "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sort</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="o "> /</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =-</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> descending</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> stable</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ):</ span >
851+ < span class ="k "> def</ span > < span class ="nf "> _get_mergesort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> ):</ span >
852+ < span class ="k "> return</ span > < span class ="n "> _sort_descending</ span > < span class ="k "> if</ span > < span class ="n "> descending</ span > < span class ="k "> else</ span > < span class ="n "> _sort_ascending</ span >
853+
854+
855+ < span class ="k "> def</ span > < span class ="nf "> _get_radixsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> ):</ span >
856+ < span class ="k "> return</ span > < span class ="n "> _radix_sort_descending</ span > < span class ="k "> if</ span > < span class ="n "> descending</ span > < span class ="k "> else</ span > < span class ="n "> _radix_sort_ascending</ span >
857+
858+
859+ < div class ="viewcode-block " id ="sort "> < a class ="viewcode-back " href ="../../../api_reference/dpctl/generated/dpctl.tensor.sort.html#dpctl.tensor.sort "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sort</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="o "> /</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =-</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> descending</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> stable</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span > < span class ="n "> kind</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ):</ span >
845860< span class ="w "> </ span > < span class ="sd "> """sort(x, axis=-1, descending=False, stable=True)</ span >
846861
847862< span class ="sd "> Returns a sorted copy of an input array `x`.</ span >
@@ -861,7 +876,10 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
861876< span class ="sd "> relative order of `x` values which compare as equal. If `False`,</ span >
862877< span class ="sd "> the returned array may or may not maintain the relative order of</ span >
863878< span class ="sd "> `x` values which compare as equal. Default: `True`.</ span >
864-
879+ < span class ="sd "> kind (Optional[Literal["stable", "mergesort", "radixsort"]]):</ span >
880+ < span class ="sd "> Sorting algorithm. The default is `"stable"`, which uses parallel</ span >
881+ < span class ="sd "> merge-sort or parallel radix-sort algorithms depending on the</ span >
882+ < span class ="sd "> array data type.</ span >
865883< span class ="sd "> Returns:</ span >
866884< span class ="sd "> usm_ndarray:</ span >
867885< span class ="sd "> a sorted array. The returned array has the same data type and</ span >
@@ -886,10 +904,33 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
886904 < span class ="n "> axis</ span > < span class ="p "> ,</ span >
887905 < span class ="p "> ]</ span >
888906 < span class ="n "> arr</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> permute_dims</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> perm</ span > < span class ="p "> )</ span >
907+ < span class ="k "> if</ span > < span class ="n "> kind</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
908+ < span class ="n "> kind</ span > < span class ="o "> =</ span > < span class ="s2 "> "stable"</ span >
909+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> kind</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> )</ span > < span class ="ow "> or</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="p "> [</ span >
910+ < span class ="s2 "> "stable"</ span > < span class ="p "> ,</ span >
911+ < span class ="s2 "> "radixsort"</ span > < span class ="p "> ,</ span >
912+ < span class ="s2 "> "mergesort"</ span > < span class ="p "> ,</ span >
913+ < span class ="p "> ]:</ span >
914+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span >
915+ < span class ="s2 "> "Unsupported kind value. Expected 'stable', 'mergesort', "</ span >
916+ < span class ="sa "> f</ span > < span class ="s2 "> "or 'radixsort', but got '</ span > < span class ="si "> {</ span > < span class ="n "> kind</ span > < span class ="si "> }</ span > < span class ="s2 "> '"</ span >
917+ < span class ="p "> )</ span >
918+ < span class ="k "> if</ span > < span class ="n "> kind</ span > < span class ="o "> ==</ span > < span class ="s2 "> "mergesort"</ span > < span class ="p "> :</ span >
919+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_mergesort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
920+ < span class ="k "> elif</ span > < span class ="n "> kind</ span > < span class ="o "> ==</ span > < span class ="s2 "> "radixsort"</ span > < span class ="p "> :</ span >
921+ < span class ="k "> if</ span > < span class ="n "> _radix_sort_dtype_supported</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> .</ span > < span class ="n "> num</ span > < span class ="p "> ):</ span >
922+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_radixsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
923+ < span class ="k "> else</ span > < span class ="p "> :</ span >
924+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Radix sort is not supported for </ span > < span class ="si "> {</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
925+ < span class ="k "> else</ span > < span class ="p "> :</ span >
926+ < span class ="n "> dt</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span >
927+ < span class ="k "> if</ span > < span class ="n "> dt</ span > < span class ="ow "> in</ span > < span class ="p "> [</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint8</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> int8</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> int16</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint16</ span > < span class ="p "> ]:</ span >
928+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_radixsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
929+ < span class ="k "> else</ span > < span class ="p "> :</ span >
930+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_mergesort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
889931 < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span >
890932 < span class ="n "> _manager</ span > < span class ="o "> =</ span > < span class ="n "> du</ span > < span class ="o "> .</ span > < span class ="n "> SequentialOrderManager</ span > < span class ="p "> [</ span > < span class ="n "> exec_q</ span > < span class ="p "> ]</ span >
891933 < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
892- < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _sort_descending</ span > < span class ="k "> if</ span > < span class ="n "> descending</ span > < span class ="k "> else</ span > < span class ="n "> _sort_ascending</ span >
893934 < span class ="k "> if</ span > < span class ="n "> arr</ span > < span class ="o "> .</ span > < span class ="n "> flags</ span > < span class ="o "> .</ span > < span class ="n "> c_contiguous</ span > < span class ="p "> :</ span >
894935 < span class ="n "> res</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty_like</ span > < span class ="p "> (</ span > < span class ="n "> arr</ span > < span class ="p "> ,</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s2 "> "C"</ span > < span class ="p "> )</ span >
895936 < span class ="n "> ht_ev</ span > < span class ="p "> ,</ span > < span class ="n "> impl_ev</ span > < span class ="o "> =</ span > < span class ="n "> impl_fn</ span > < span class ="p "> (</ span >
@@ -921,7 +962,15 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
921962 < span class ="k "> return</ span > < span class ="n "> res</ span > </ div >
922963
923964
924- < div class ="viewcode-block " id ="argsort "> < a class ="viewcode-back " href ="../../../api_reference/dpctl/generated/dpctl.tensor.argsort.html#dpctl.tensor.argsort "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> argsort</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =-</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> descending</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> stable</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ):</ span >
965+ < span class ="k "> def</ span > < span class ="nf "> _get_mergeargsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> ):</ span >
966+ < span class ="k "> return</ span > < span class ="n "> _argsort_descending</ span > < span class ="k "> if</ span > < span class ="n "> descending</ span > < span class ="k "> else</ span > < span class ="n "> _argsort_ascending</ span >
967+
968+
969+ < span class ="k "> def</ span > < span class ="nf "> _get_radixargsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> ):</ span >
970+ < span class ="k "> return</ span > < span class ="n "> _radix_argsort_descending</ span > < span class ="k "> if</ span > < span class ="n "> descending</ span > < span class ="k "> else</ span > < span class ="n "> _radix_argsort_ascending</ span >
971+
972+
973+ < div class ="viewcode-block " id ="argsort "> < a class ="viewcode-back " href ="../../../api_reference/dpctl/generated/dpctl.tensor.argsort.html#dpctl.tensor.argsort "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> argsort</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =-</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> descending</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> stable</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span > < span class ="n "> kind</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ):</ span >
925974< span class ="w "> </ span > < span class ="sd "> """argsort(x, axis=-1, descending=False, stable=True)</ span >
926975
927976< span class ="sd "> Returns the indices that sort an array `x` along a specified axis.</ span >
@@ -941,6 +990,10 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
941990< span class ="sd "> relative order of `x` values which compare as equal. If `False`,</ span >
942991< span class ="sd "> the returned array may or may not maintain the relative order of</ span >
943992< span class ="sd "> `x` values which compare as equal. Default: `True`.</ span >
993+ < span class ="sd "> kind (Optional[Literal["stable", "mergesort", "radixsort"]]):</ span >
994+ < span class ="sd "> Sorting algorithm. The default is `"stable"`, which uses parallel</ span >
995+ < span class ="sd "> merge-sort or parallel radix-sort algorithms depending on the</ span >
996+ < span class ="sd "> array data type.</ span >
944997
945998< span class ="sd "> Returns:</ span >
946999< span class ="sd "> usm_ndarray:</ span >
@@ -969,10 +1022,33 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
9691022 < span class ="n "> axis</ span > < span class ="p "> ,</ span >
9701023 < span class ="p "> ]</ span >
9711024 < span class ="n "> arr</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> permute_dims</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> perm</ span > < span class ="p "> )</ span >
1025+ < span class ="k "> if</ span > < span class ="n "> kind</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
1026+ < span class ="n "> kind</ span > < span class ="o "> =</ span > < span class ="s2 "> "stable"</ span >
1027+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> kind</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> )</ span > < span class ="ow "> or</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="p "> [</ span >
1028+ < span class ="s2 "> "stable"</ span > < span class ="p "> ,</ span >
1029+ < span class ="s2 "> "radixsort"</ span > < span class ="p "> ,</ span >
1030+ < span class ="s2 "> "mergesort"</ span > < span class ="p "> ,</ span >
1031+ < span class ="p "> ]:</ span >
1032+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span >
1033+ < span class ="s2 "> "Unsupported kind value. Expected 'stable', 'mergesort', "</ span >
1034+ < span class ="sa "> f</ span > < span class ="s2 "> "or 'radixsort', but got '</ span > < span class ="si "> {</ span > < span class ="n "> kind</ span > < span class ="si "> }</ span > < span class ="s2 "> '"</ span >
1035+ < span class ="p "> )</ span >
1036+ < span class ="k "> if</ span > < span class ="n "> kind</ span > < span class ="o "> ==</ span > < span class ="s2 "> "mergesort"</ span > < span class ="p "> :</ span >
1037+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_mergeargsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
1038+ < span class ="k "> elif</ span > < span class ="n "> kind</ span > < span class ="o "> ==</ span > < span class ="s2 "> "radixsort"</ span > < span class ="p "> :</ span >
1039+ < span class ="k "> if</ span > < span class ="n "> _radix_sort_dtype_supported</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="o "> .</ span > < span class ="n "> num</ span > < span class ="p "> ):</ span >
1040+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_radixargsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
1041+ < span class ="k "> else</ span > < span class ="p "> :</ span >
1042+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Radix sort is not supported for </ span > < span class ="si "> {</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
1043+ < span class ="k "> else</ span > < span class ="p "> :</ span >
1044+ < span class ="n "> dt</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span >
1045+ < span class ="k "> if</ span > < span class ="n "> dt</ span > < span class ="ow "> in</ span > < span class ="p "> [</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint8</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> int8</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> int16</ span > < span class ="p "> ,</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> uint16</ span > < span class ="p "> ]:</ span >
1046+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_radixargsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
1047+ < span class ="k "> else</ span > < span class ="p "> :</ span >
1048+ < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_mergeargsort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
9721049 < span class ="n "> exec_q</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span >
9731050 < span class ="n "> _manager</ span > < span class ="o "> =</ span > < span class ="n "> du</ span > < span class ="o "> .</ span > < span class ="n "> SequentialOrderManager</ span > < span class ="p "> [</ span > < span class ="n "> exec_q</ span > < span class ="p "> ]</ span >
9741051 < span class ="n "> dep_evs</ span > < span class ="o "> =</ span > < span class ="n "> _manager</ span > < span class ="o "> .</ span > < span class ="n "> submitted_events</ span >
975- < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _argsort_descending</ span > < span class ="k "> if</ span > < span class ="n "> descending</ span > < span class ="k "> else</ span > < span class ="n "> _argsort_ascending</ span >
9761052 < span class ="n "> index_dt</ span > < span class ="o "> =</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> default_device_index_type</ span > < span class ="p "> (</ span > < span class ="n "> exec_q</ span > < span class ="p "> )</ span >
9771053 < span class ="k "> if</ span > < span class ="n "> arr</ span > < span class ="o "> .</ span > < span class ="n "> flags</ span > < span class ="o "> .</ span > < span class ="n "> c_contiguous</ span > < span class ="p "> :</ span >
9781054 < span class ="n "> res</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> empty_like</ span > < span class ="p "> (</ span > < span class ="n "> arr</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> index_dt</ span > < span class ="p "> ,</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s2 "> "C"</ span > < span class ="p "> )</ span >
0 commit comments