@@ -891,6 +891,15 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
891891 < span class ="n "> arr</ span > < span class ="o "> =</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> permute_dims</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> perm</ span > < span class ="p "> )</ span >
892892 < span class ="k "> if</ span > < span class ="n "> kind</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
893893 < span class ="n "> kind</ span > < span class ="o "> =</ span > < span class ="s2 "> "stable"</ span >
894+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> kind</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> )</ span > < span class ="ow "> or</ span > < span class ="n "> kind</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="p "> [</ span >
895+ < span class ="s2 "> "stable"</ span > < span class ="p "> ,</ span >
896+ < span class ="s2 "> "radixsort"</ span > < span class ="p "> ,</ span >
897+ < span class ="s2 "> "mergesort"</ span > < span class ="p "> ,</ span >
898+ < span class ="p "> ]:</ span >
899+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span >
900+ < span class ="s2 "> "Unsupported kind value. Expected 'stable', 'mergesort', "</ span >
901+ < span class ="sa "> f</ span > < span class ="s2 "> "or 'radixsort', but got '</ span > < span class ="si "> {</ span > < span class ="n "> kind</ span > < span class ="si "> }</ span > < span class ="s2 "> '"</ span >
902+ < span class ="p "> )</ span >
894903 < span class ="k "> if</ span > < span class ="n "> kind</ span > < span class ="o "> ==</ span > < span class ="s2 "> "mergesort"</ span > < span class ="p "> :</ span >
895904 < span class ="n "> impl_fn</ span > < span class ="o "> =</ span > < span class ="n "> _get_mergesort_impl_fn</ span > < span class ="p "> (</ span > < span class ="n "> descending</ span > < span class ="p "> )</ span >
896905 < span class ="k "> elif</ span > < span class ="n "> kind</ span > < span class ="o "> ==</ span > < span class ="s2 "> "radixsort"</ span > < span class ="p "> :</ span >
@@ -989,6 +998,10 @@ <h1>Source code for dpctl.tensor._sorting</h1><div class="highlight"><pre>
989998 < span class ="p "> )</ span >
990999 < span class ="k "> else</ span > < span class ="p "> :</ span >
9911000 < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="n "> normalize_axis_index</ span > < span class ="p "> (</ span > < span class ="n "> axis</ span > < span class ="p "> ,</ span > < span class ="n "> ndim</ span > < span class ="o "> =</ span > < span class ="n "> nd</ span > < span class ="p "> ,</ span > < span class ="n "> msg_prefix</ span > < span class ="o "> =</ span > < span class ="s2 "> "axis"</ span > < span class ="p "> )</ span >
1001+ < span class ="k "> if</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="o "> ==</ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1002+ < span class ="k "> return</ span > < span class ="n "> dpt</ span > < span class ="o "> .</ span > < span class ="n "> zeros_like</ span > < span class ="p "> (</ span >
1003+ < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> ti</ span > < span class ="o "> .</ span > < span class ="n "> default_device_index_type</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> sycl_queue</ span > < span class ="p "> ),</ span > < span class ="n "> order</ span > < span class ="o "> =</ span > < span class ="s2 "> "C"</ span >
1004+ < span class ="p "> )</ span >
9921005 < span class ="n "> a1</ span > < span class ="o "> =</ span > < span class ="n "> axis</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span >
9931006 < span class ="k "> if</ span > < span class ="n "> a1</ span > < span class ="o "> ==</ span > < span class ="n "> nd</ span > < span class ="p "> :</ span >
9941007 < span class ="n "> perm</ span > < span class ="o "> =</ span > < span class ="nb "> list</ span > < span class ="p "> (</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> nd</ span > < span class ="p "> ))</ span >
0 commit comments