diff --git a/src/ATen/native/sparse/xpu/sycl/SparseBinaryOpIntersectionKernels.cpp b/src/ATen/native/sparse/xpu/sycl/SparseBinaryOpIntersectionKernels.cpp
index f55b8a6f08..08ecd1f254 100644
--- a/src/ATen/native/sparse/xpu/sycl/SparseBinaryOpIntersectionKernels.cpp
+++ b/src/ATen/native/sparse/xpu/sycl/SparseBinaryOpIntersectionKernels.cpp
@@ -220,6 +220,10 @@ struct ValueSelectionIntersectionKernel {
 
 template
 struct SparseBinaryOpIntersectionAFunctor {
+  // For std::is_trivially_copyable
+  SparseBinaryOpIntersectionAFunctor& operator=(
+      const SparseBinaryOpIntersectionAFunctor&) = delete;
+
   int64_t operator()(index_t nnz_idx) const {
     int64_t hash = 0;
     if (!ptr_indices_) {
@@ -256,6 +260,10 @@ struct SparseBinaryOpIntersectionAFunctor {
 
 template
 struct SparseBinaryOpIntersectionBFunctor {
+  // For std::is_trivially_copyable
+  SparseBinaryOpIntersectionBFunctor& operator=(
+      const SparseBinaryOpIntersectionBFunctor&) = delete;
+
   index_t operator()(index_t nnz_idx) const {
     int64_t hash = 0;
     if (hash_ptr_) {
diff --git a/src/ATen/native/sparse/xpu/sycl/SparseTensorKernels.cpp b/src/ATen/native/sparse/xpu/sycl/SparseTensorKernels.cpp
index 2767f49a6f..31ceb4ef47 100644
--- a/src/ATen/native/sparse/xpu/sycl/SparseTensorKernels.cpp
+++ b/src/ATen/native/sparse/xpu/sycl/SparseTensorKernels.cpp
@@ -211,6 +211,9 @@ struct KernelLauncher {
 
 template
 struct FlattenIndicesFunctor {
+  // For std::is_trivially_copyable
+  FlattenIndicesFunctor& operator=(const FlattenIndicesFunctor&) = delete;
+
   int64_t operator()(int64_t nnz_idx) const {
     const auto* ptr_indices_dim = ptr_indices_ + nnz_idx * indices_nnz_stride_;
     auto hash = static_cast(0);
diff --git a/src/ATen/native/xpu/sycl/ActivationHardshrinkKernels.cpp b/src/ATen/native/xpu/sycl/ActivationHardshrinkKernels.cpp
index 896d0faf80..319cc31c8a 100644
--- a/src/ATen/native/xpu/sycl/ActivationHardshrinkKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ActivationHardshrinkKernels.cpp
@@ -11,10 +11,10 @@ struct HardshrinkFunctor {
     return (a >= -lambd_ && a <= lambd_) ? scalar_t(0) : a;
   }
 
-  HardshrinkFunctor(const scalar_t lambd) : lambd_(lambd) {}
+  HardshrinkFunctor(scalar_t lambd) : lambd_(lambd) {}
 
  private:
-  const scalar_t lambd_;
+  scalar_t lambd_;
 };
 
 void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
diff --git a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
index d0a39ab2f8..011d1ab451 100644
--- a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
+++ b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
@@ -301,8 +301,8 @@ struct AdaptiveAvgPool2dBwdSLMChannelsLastKernelFunctor
     index_t istartW = item.get_local_id(1) + item.get_group(1) * iW;
     index_t iendW = std::min(istartW + iW, isizeW_);
 
-    // Stride for threads, each subgroup can reuse L1 as they go. So theoretically
-    // better chance to survive cache eviction.
+    // Stride for threads, each subgroup can reuse L1 as they go. So
+    // theoretically better chance to survive cache eviction.
     for (index_t ih = istartH; ih < iendH; ih += item.get_local_range(0)) {
       index_t ostartH = START_IND_INT(ih, isizeH_, osizeH_);
       index_t oendH = END_IND_INT(ih, isizeH_, osizeH_);
diff --git a/src/ATen/native/xpu/sycl/Loops.h b/src/ATen/native/xpu/sycl/Loops.h
index 8c4fbbdef5..728835b294 100644
--- a/src/ATen/native/xpu/sycl/Loops.h
+++ b/src/ATen/native/xpu/sycl/Loops.h
@@ -89,49 +89,40 @@ struct UnrolledElementwiseKernel {
 };
 
 template
-struct VectorizedElementwiseKernel {
-  void operator()(sycl::nd_item<1> item) const {
-    int grpsz = item.get_local_range(0);
-    int grpid = item.get_group(0);
-    int lid = item.get_local_id(0);
-    int group_work_size = vec_size * grpsz;
-    int remaining = numel_ - grpid * group_work_size;
-
-    // ic_
-    if (remaining < group_work_size) {
-      auto oc = TrivialOffsetCalculator<1>();
-      auto l = at::native::memory::LoadWithoutCast();
-      auto s = at::native::memory::StoreWithoutCast();
-      auto policy = at::native::memory::policies::unroll<
-          vec_size,
-          array_t,
-          decltype(ic_),
-          decltype(oc),
-          at::native::memory::LoadWithoutCast,
-          at::native::memory::StoreWithoutCast>(
-          data_, remaining, ic_, oc, l, s, lid, grpid, grpsz);
-      elementwise_kernel_helper(f_, policy);
-    } else {
-      auto policy = at::native::memory::policies::
-          vectorized(
-              data_, ic_, lid, grpid, grpsz);
-      elementwise_kernel_helper(f_, policy);
-    }
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
+void vectorized_elementwise_kernel(
+    int numel,
+    const func_t f,
+    array_t data,
+    in_calc_t ic) {
+  auto item = syclext::this_work_item::get_nd_item<1>();
+  int grpsz = item.get_local_range(0);
+  int grpid = item.get_group(0);
+  int lid = item.get_local_id(0);
+  int group_work_size = vec_size * grpsz;
+  int remaining = numel - grpid * group_work_size;
+
+  // ic_
+  if (remaining < group_work_size) {
+    auto oc = TrivialOffsetCalculator<1>();
+    auto l = at::native::memory::LoadWithoutCast();
+    auto s = at::native::memory::StoreWithoutCast();
+    auto policy = at::native::memory::policies::unroll<
+        vec_size,
+        array_t,
+        decltype(ic),
+        decltype(oc),
+        at::native::memory::LoadWithoutCast,
+        at::native::memory::StoreWithoutCast>(
+        data, remaining, ic, oc, l, s, lid, grpid, grpsz);
+    elementwise_kernel_helper(f, policy);
+  } else {
+    auto policy =
+        at::native::memory::policies::vectorized(
+            data, ic, lid, grpid, grpsz);
+    elementwise_kernel_helper(f, policy);
+  }
-
-  VectorizedElementwiseKernel(
-      int numel,
-      const func_t f,
-      array_t data,
-      in_calc_t ic)
-      : numel_(numel), f_(f), data_(data), ic_(ic) {}
-
- private:
-  int numel_;
-  const func_t f_;
-  array_t data_;
-  in_calc_t ic_;
-};
+}
 
 template <
     int num_outputs,
     typename func_t,
@@ -391,16 +382,24 @@ static inline void launch_vectorized_kernel(
   using traits = function_traits;
   auto wg_sz = syclMaxWorkItemsPerSubSlice();
 
-#define VEC_KER(vec_size) \
-  { \
-    TORCH_CHECK(max_scalar_bytes* vec_size <= 16); \
-    if constexpr (max_scalar_bytes * vec_size <= 16) { \
-      auto ker = \
-          VectorizedElementwiseKernel( \
-              N, f, data, input_calc); \
-      int num_wg = ceil_div(N, wg_sz * vec_size); \
-      sycl_kernel_submit(wg_sz* num_wg, wg_sz, getCurrentSYCLQueue(), ker); \
-    } \
+#define VEC_KER(vec_size) \
+  { \
+    TORCH_CHECK(max_scalar_bytes* vec_size <= 16); \
+    if constexpr (max_scalar_bytes * vec_size <= 16) { \
+      int num_wg = ceil_div(N, wg_sz * vec_size); \
+      sycl_kernel_submit>( \
+          wg_sz * num_wg, \
+          wg_sz, \
+          getCurrentSYCLQueue(), \
+          N, \
+          f, \
+          data, \
+          input_calc); \
+    } \
   }
 
   switch (vec_size) {
diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h
index 0c713ee51d..05c804202e 100644
--- a/src/ATen/native/xpu/sycl/Reduce.h
+++ b/src/ATen/native/xpu/sycl/Reduce.h
@@ -1124,6 +1124,13 @@ int get_output_vec_size(at::TensorIterator& iter) {
   return std::min(vec_size, vt1);
 }
 
+template <
+    typename scalar_t,
+    typename out_scalar_t,
+    int vt0,
+    int vt1,
+    typename ops_t,
+    typename ident_t>
 struct ReduceKernelEmptyFunctor {
   char operator()() const {
     return 0;
@@ -1374,7 +1381,8 @@ inline void gpu_reduce_kernel(
       config.semaphore_size(), at::TensorOptions().dtype(kChar).device(kXPU));
   at::detail::Array data;
   data[0] = (char*)semaphores.data_ptr();
-  ReduceKernelEmptyFunctor fn;
+  ReduceKernelEmptyFunctor
+      fn;
   int vec_size = at::native::memory::can_vectorize_up_to(data);
   auto ic = TrivialOffsetCalculator();
   launch_vectorized_kernel(config.semaphore_size(), fn, data, ic, vec_size);
diff --git a/src/ATen/native/xpu/sycl/SortingKernels.h b/src/ATen/native/xpu/sycl/SortingKernels.h
index cce01ba4b5..2aa1ae65fc 100644
--- a/src/ATen/native/xpu/sycl/SortingKernels.h
+++ b/src/ATen/native/xpu/sycl/SortingKernels.h
@@ -316,13 +316,6 @@ void segmented_radix_sort_pairs_downsweep_kernel(
 
 // ======================= large sort =======================
 
-template
-struct ABBufferCopyFunctor {
-  scalar_t operator()(scalar_t x) const {
-    return x;
-  }
-};
-
 template <
     typename key_t,
     typename value_t,
@@ -416,20 +409,18 @@ void segmented_radix_sort_pairs_kernel(
     auto input_calc = TrivialOffsetCalculator<2>();
     at::detail::Array data;
     if (keys_out) {
-      data[0] = (char*)keys_out;
-      data[1] = (char*)keys_temp;
-      auto fn = ABBufferCopyFunctor();
-      auto vec_size = memory::can_vectorize_up_to(data);
-      launch_vectorized_kernel(
-          num_segments * num_elements, fn, data, input_calc, vec_size);
+      auto q = at::xpu::getCurrentSYCLQueue();
+      q.memcpy(
+          (void*)keys_out,
+          (void*)keys_temp,
+          sizeof(key_t) * num_segments * num_elements);
     }
     if (values_out) {
-      data[0] = (char*)values_out;
-      data[1] = (char*)values_temp;
-      auto fn = ABBufferCopyFunctor();
-      auto vec_size = memory::can_vectorize_up_to(data);
-      launch_vectorized_kernel(
-          num_segments * num_elements, fn, data, input_calc, vec_size);
+      auto q = at::xpu::getCurrentSYCLQueue();
+      q.memcpy(
+          (void*)values_out,
+          (void*)values_temp,
+          sizeof(value_t) * num_segments * num_elements);
     }
   }
 }
diff --git a/src/ATen/native/xpu/sycl/TensorCompare.cpp b/src/ATen/native/xpu/sycl/TensorCompare.cpp
index 7259373bb8..d25e9a5440 100644
--- a/src/ATen/native/xpu/sycl/TensorCompare.cpp
+++ b/src/ATen/native/xpu/sycl/TensorCompare.cpp
@@ -35,47 +35,6 @@ struct ClampFunctor {
   }
 };
 
-template
-struct ClampScalarFunctor {
-  using opmath_t = at::opmath_type;
-  scalar_t operator()(scalar_t v) const {
-    if (_isnan(static_cast(v))) {
-      return v;
-    } else if (minmax_ == at::native::detail::ClampLimits::Min) {
-      return std::max(static_cast(v), lim0_val_);
-    } else if (minmax_ == at::native::detail::ClampLimits::Max) {
-      return std::min(static_cast(v), lim0_val_);
-    } else {
-      return std::min(std::max(static_cast(v), lim0_val_), lim1_val_);
-    }
-  }
-  ClampScalarFunctor(
-      opmath_t lim0_val,
-      opmath_t lim1_val,
-      at::native::detail::ClampLimits minmax)
-      : lim0_val_(lim0_val), lim1_val_(lim1_val), minmax_(minmax) {}
-
- private:
-  opmath_t lim0_val_;
-  opmath_t lim1_val_;
-  at::native::detail::ClampLimits minmax_;
-};
-
-void inline launch_clamp_scalar(
-    TensorIteratorBase& iter,
-    Scalar lim0,
-    Scalar lim1,
-    at::native::detail::ClampLimits minmax) {
-  AT_DISPATCH_ALL_TYPES_AND2(
-      kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_xpu", [&] {
-        using opmath_t = at::opmath_type;
-        auto lim0_val = lim0.to();
-        auto lim1_val = lim1.to();
-        gpu_kernel(
-            iter, ClampScalarFunctor(lim0_val, lim1_val, minmax));
-      });
-}
-
 } // namespace xpu
 } // namespace native
 } // namespace at
diff --git a/src/comm/SYCLHelpers.h b/src/comm/SYCLHelpers.h
index 48390229f4..34d33698cc 100644
--- a/src/comm/SYCLHelpers.h
+++ b/src/comm/SYCLHelpers.h
@@ -3,6 +3,9 @@
 #include
 #include
 
+namespace syclext = sycl::ext::oneapi;
+namespace syclexp = sycl::ext::oneapi::experimental;
+
 // sycl access address space
 static constexpr auto sycl_priv_space =
     sycl::access::address_space::private_space;
@@ -139,6 +142,22 @@ sycl_kernel_submit(
   q.submit(cgf);
 }
 
+// For SYCL free function
+template
+static inline void sycl_kernel_submit(
+    int64_t global_range,
+    int64_t local_range,
+    ::sycl::queue q,
+    Kargs... args) {
+  sycl::context ctxt = q.get_context();
+  auto exe_bndl =
+      syclexp::get_kernel_bundle(ctxt);
+  sycl::kernel ker = exe_bndl.template ext_oneapi_get_kernel();
+  syclexp::launch_config cfg{::sycl::nd_range<1>(
+      ::sycl::range<1>(global_range), ::sycl::range<1>(local_range))};
+  syclexp::nd_launch(q, cfg, ker, args...);
+}
+
 #define SYCL_KERNEL_STRING(var, str) \
   static const __attribute__((opencl_constant)) char var[] = str;
 #define SYCL_KERNEL_PRINTF sycl::ext::oneapi::experimental::printf
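
Note on the `// For std::is_trivially_copyable` hunks: C++ only counts the non-deleted special members when classifying a type as trivially copyable, so explicitly deleting the copy-assignment operator leaves just the (trivial) copy constructor and destructor in play. Below is a minimal sketch of that rule; the `Member`/`Functor` names are illustrative and not from this patch, and the presumed motivation is that the functors are passed as SYCL kernel arguments, which must be device copyable.

```cpp
#include <type_traits>

struct Member {
  int v = 0;
  Member() = default;
  Member(const Member&) = default;
  Member& operator=(const Member& other) {  // user-provided => non-trivial
    v = other.v;
    return *this;
  }
};

struct FunctorA {  // keeps the implicit, non-trivial copy assignment
  Member m;
  const int* ptr = nullptr;
};

struct FunctorB {  // mirrors the patch: delete the copy assignment
  Member m;
  const int* ptr = nullptr;
  FunctorB& operator=(const FunctorB&) = delete;
};

// Only the deleted assignment is excluded from the check; the remaining
// copy constructor and destructor are trivial, so FunctorB qualifies.
static_assert(!std::is_trivially_copyable_v<FunctorA>);
static_assert(std::is_trivially_copyable_v<FunctorB>);
```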
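
The Loops.h and SYCLHelpers.h hunks move the vectorized elementwise kernel from a named functor to a SYCL free-function kernel: the function is annotated with `SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))`, looked up as a `sycl::kernel` from the executable kernel bundle, and launched through `syclexp::nd_launch` with its arguments passed positionally. The sketch below shows that flow end to end under stated assumptions: the kernel name and ranges are illustrative, and the bundle/kernel lookup signatures (whose template arguments are elided in the patch text) follow my reading of the oneAPI free-function-kernel and enqueue-functions extensions, not this repository's helpers.

```cpp
#include <sycl/sycl.hpp>

namespace syclext = sycl::ext::oneapi;
namespace syclexp = sycl::ext::oneapi::experimental;

// Mark a plain function as a 1-D nd-range kernel entry point.
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void scale_kernel(float* data, int n, float factor) {
  auto item = syclext::this_work_item::get_nd_item<1>();
  int i = static_cast<int>(item.get_global_linear_id());
  if (i < n)
    data[i] *= factor;
}

int main() {
  sycl::queue q;
  sycl::context ctxt = q.get_context();

  // Fetch the compiled kernel object from the executable bundle
  // (assumed lookup API of the free-function-kernels extension).
  auto exe_bndl =
      sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctxt);
  sycl::kernel ker = exe_bndl.ext_oneapi_get_kernel<scale_kernel>();

  constexpr int n = 1024;
  float* data = sycl::malloc_shared<float>(n, q);
  for (int i = 0; i < n; ++i)
    data[i] = static_cast<float>(i);

  // Launch: the nd_range comes from the launch_config, and kernel arguments
  // are passed positionally, matching the free function's parameter list.
  syclexp::launch_config cfg{
      sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(128))};
  syclexp::nd_launch(q, cfg, ker, data, n, 2.0f);
  q.wait();

  sycl::free(data, q);
  return 0;
}
```

The new `sycl_kernel_submit` overload in SYCLHelpers.h wraps exactly the bundle-lookup-plus-`nd_launch` steps, so call sites such as the `VEC_KER` macro only supply the ranges, the queue, and the kernel arguments.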
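
In SortingKernels.h the temp-to-output copy no longer goes through a vectorized elementwise kernel; it is a plain `sycl::queue::memcpy` of `sizeof(T) * num_segments * num_elements` bytes, which is sufficient because both buffers are contiguous device allocations of a trivially copyable element type. A minimal sketch of that pattern, with an illustrative helper name:

```cpp
#include <sycl/sycl.hpp>
#include <type_traits>

// Copy `count` elements of a trivially copyable type between USM buffers
// accessible to q's device; queue::memcpy takes a byte count and returns
// an event the caller may wait on.
template <typename T>
sycl::event copy_elements(sycl::queue& q, T* dst, const T* src, size_t count) {
  static_assert(std::is_trivially_copyable_v<T>, "byte-wise copy only");
  return q.memcpy(static_cast<void*>(dst), static_cast<const void*>(src),
                  sizeof(T) * count);
}
```

Whether the caller must wait on the returned event depends on the ordering guarantees of the queue in use, which this sketch does not assume.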