#include <ATen/native/CompositeRandomAccessor.h>
#include <ATen/native/TopKImpl.h>
#include <c10/core/WrapDimMinimal.h>
+#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>
+
#ifdef USE_FBGEMM
#include <fbgemm/Utils.h>
#endif

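+// x86-simd-sort is only usable when it is enabled at build time and this
+// translation unit is compiled with AVX2 or AVX512 support.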
+#if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2))
+#define XSS_COMPILE_TIME_SUPPORTED
+#include <src/x86simdsort-static-incl.h>
+#endif
+
namespace at::native {

namespace {
@@ -119,6 +126,7 @@ static void parallel_sort1d_kernel(
  std::vector<int64_t> tmp_vals(elements);
  const scalar_t* sorted_keys = nullptr;
  const int64_t* sorted_vals = nullptr;
+
  std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
      keys,
      vals,
@@ -167,6 +175,116 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
  }
}

+#if defined(XSS_COMPILE_TIME_SUPPORTED)
+
+#define AT_DISPATCH_CASE_XSS_TYPES(...)                  \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
+
+#define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__))
+
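+// x86-simd-sort only handles the four dtypes dispatched above and does not
+// provide a stable sort, so every other case falls back to the existing kernels.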
+static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) {
+  // xss_sort is not a stable sort
+  if (stable) return false;
+
+  auto type = values.scalar_type();
+  if (!(type == ScalarType::Long || type == ScalarType::Int || type == ScalarType::Double || type == ScalarType::Float)) return false;
+
+  return true;
+}
+
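+// Choose between x86-simd-sort and fbgemm's radix sort when both are built in.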
+static bool xss_sort_preferred(const TensorBase& values, const bool descending) {
+#if defined(XSS_USE_OPENMP) || !defined(USE_FBGEMM)
+  return true;
+#else
+  // Without OpenMP support for x86-simd-sort, fbgemm radix sort is faster when it can be used
+  return !can_use_radix_sort(values, descending);
+#endif
+}
+
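+// Sorts `values` along `dim` with x86-simd-sort, writing the sorted keys back
+// into `values` and the matching positions into `indices`. A TensorIterator
+// with the sort dimension squashed hands the loop batches of 1-D slices.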
+static void xss_sort_kernel(
+    const TensorBase& values,
+    const TensorBase& indices,
+    int64_t dim,
+    bool descending) {
+  auto iter = TensorIteratorConfig()
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .declare_static_shape(values.sizes(), /*squash_dims=*/dim)
+    .add_output(values)
+    .add_output(indices)
+    .build();
+
+  using index_t = int64_t;
+
+  AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] {
+
+    auto values_dim_stride = values.stride(dim);
+    auto indices_dim_stride = indices.stride(dim);
+    auto dim_size = values.size(dim);
+
+    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+      auto* values_data_bytes = data[0];
+      auto* indices_data_bytes = data[1];
+
+      if (values_data_bytes == nullptr || indices_data_bytes == nullptr) {
+        return;
+      }
+
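+      // Fast path: both outputs are contiguous along the sort dimension, so
+      // each slice can be sorted in place.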
+      if (values_dim_stride == 1 && indices_dim_stride == 1) {
+        for (const auto i C10_UNUSED : c10::irange(n)) {
+          x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
+              reinterpret_cast<scalar_t*>(values_data_bytes),
+              reinterpret_cast<index_t*>(indices_data_bytes),
+              dim_size,
+              true,
+              descending);
+
+          values_data_bytes += strides[0];
+          indices_data_bytes += strides[1];
+        }
+      } else {
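+        // Strided case: copy each slice into contiguous scratch buffers, sort
+        // the key/index pairs there, then write the results back through the
+        // accessors.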
+        c10::SmallBuffer<scalar_t, 0> tmp_values(dim_size);
+        c10::SmallBuffer<index_t, 0> tmp_indices(dim_size);
+
+        for (const auto i : c10::irange(n)) {
+          TensorAccessor<scalar_t, 1> mode_values_acc(
+              reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
+              &dim_size, &values_dim_stride);
+          TensorAccessor<index_t, 1> mode_indices_acc(
+              reinterpret_cast<index_t*>(data[1] + i * strides[1]),
+              &dim_size, &indices_dim_stride);
+
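+          // Stage the slice and seed the indices with 0..dim_size-1.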
+          for (const auto j : c10::irange(dim_size)) {
+            tmp_values[j] = mode_values_acc[j];
+            tmp_indices[j] = j;
+          }
+
+          x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
+              tmp_values.data(),
+              tmp_indices.data(),
+              dim_size,
+              true,
+              descending);
+
+          for (const auto j : c10::irange(dim_size)) {
+            mode_values_acc[j] = tmp_values[j];
+            mode_indices_acc[j] = tmp_indices[j];
+          }
+        }
+      }
+    };
+
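+    // Scale the parallel grain size by the slice length so each task sorts
+    // roughly GRAIN_SIZE elements' worth of slices.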
+    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
+    iter.for_each(loop, /*grain_size=*/grain_size);
+
+  });
+}
+#endif
+
static void sort_kernel(
    const TensorBase& self,
    const TensorBase& values,
@@ -181,6 +299,14 @@ static void sort_kernel(
    // https://github.com/pytorch/pytorch/issues/91420
    return;
  }
+
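+  // Try the x86-simd-sort path first; fall through to fbgemm radix sort or the
+  // generic comparator sort below when it does not apply.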
+#if defined(XSS_COMPILE_TIME_SUPPORTED)
+  if (can_use_xss_sort(values, indices, dim, stable) && xss_sort_preferred(values, descending)) {
+    xss_sort_kernel(values, indices, dim, descending);
+    return;
+  }
+#endif
+
#ifdef USE_FBGEMM
  if (can_use_radix_sort(values, descending)) {
    parallel_sort1d_kernel(values, indices);
@@ -232,6 +358,7 @@ static void topk_kernel(
    int64_t dim,
    bool largest,
    bool sorted) {
+
  auto sizes = self.sizes();
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
@@ -266,7 +393,7 @@ static void topk_kernel(

} // anonymous namespace

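+// Also register these kernels for the AVX512 build so the AVX512-compiled
+// x86-simd-sort path can be dispatched to.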
-REGISTER_DISPATCH(sort_stub, &sort_kernel);
-REGISTER_DISPATCH(topk_stub, &topk_kernel);
+ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel);
+ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel);

} // at::native