@@ -79,12 +79,23 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
7979 raise TypeError (
8080 f"Expected type dpctl.tensor.usm_ndarray, got { type (x )} "
8181 )
82+ if not isinstance (kind , str ) or kind not in [
83+ "stable" ,
84+ "radixsort" ,
85+ "mergesort" ,
86+ ]:
87+ raise ValueError (
88+ "Unsupported kind value. Expected 'stable', 'mergesort', "
89+ f"or 'radixsort', but got '{ kind } '"
90+ )
8291 nd = x .ndim
8392 if nd == 0 :
8493 axis = normalize_axis_index (axis , ndim = 1 , msg_prefix = "axis" )
8594 return dpt .copy (x , order = "C" )
8695 else :
8796 axis = normalize_axis_index (axis , ndim = nd , msg_prefix = "axis" )
97+ if x .size == 1 :
98+ return dpt .copy (x , order = "C" )
8899 a1 = axis + 1
89100 if a1 == nd :
90101 perm = list (range (nd ))
@@ -96,15 +107,6 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
96107 arr = dpt .permute_dims (x , perm )
97108 if kind is None :
98109 kind = "stable"
99- if not isinstance (kind , str ) or kind not in [
100- "stable" ,
101- "radixsort" ,
102- "mergesort" ,
103- ]:
104- raise ValueError (
105- "Unsupported kind value. Expected 'stable', 'mergesort', "
106- f"or 'radixsort', but got '{ kind } '"
107- )
108110 if kind == "mergesort" :
109111 impl_fn = _get_mergesort_impl_fn (descending )
110112 elif kind == "radixsort" :
0 commit comments