
Implement SYCL free function style for vectorized loops kernel #1906


Draft: wants to merge 2 commits into main. Changes shown from 1 commit.
@@ -220,6 +220,10 @@ struct ValueSelectionIntersectionKernel {

template <typename index_t, typename hash_coeffs_t>
struct SparseBinaryOpIntersectionAFunctor {
// For std::is_trivially_copyable
SparseBinaryOpIntersectionAFunctor& operator=(
const SparseBinaryOpIntersectionAFunctor&) = delete;

int64_t operator()(index_t nnz_idx) const {
int64_t hash = 0;
if (!ptr_indices_) {
@@ -256,6 +260,10 @@ struct SparseBinaryOpIntersectionAFunctor {

template <typename index_t, typename hash_coeffs_t, typename hash_t = int64_t>
struct SparseBinaryOpIntersectionBFunctor {
// For std::is_trivially_copyable
SparseBinaryOpIntersectionBFunctor& operator=(
const SparseBinaryOpIntersectionBFunctor&) = delete;

index_t operator()(index_t nnz_idx) const {
int64_t hash = 0;
if (hash_ptr_) {
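A note on the deleted copy assignments above: with the free-function kernel style, these functors are passed to the kernel as plain arguments (via `cgh.set_args`), so they presumably need to satisfy `std::is_trivially_copyable`, as the new comments indicate. The sketch below only illustrates the trait's rule with a made-up functor, not code from this PR: explicitly deleting the copy assignment operator is compatible with `std::is_trivially_copyable`, because the type still has a trivial, eligible copy constructor and a trivial destructor.

```cpp
#include <cstdint>
#include <type_traits>

// Illustrative functor (hypothetical, not from the PR).
struct ExampleIntersectionFunctor {
  // Deleted copy assignment, mirroring the change above; the implicitly
  // declared copy constructor remains trivial, so the type as a whole is
  // still trivially copyable.
  ExampleIntersectionFunctor& operator=(const ExampleIntersectionFunctor&) =
      delete;

  int64_t operator()(int64_t nnz_idx) const {
    return ptr_indices_ ? ptr_indices_[nnz_idx * stride_] : int64_t{0};
  }

  const int64_t* ptr_indices_;
  int64_t stride_;
};

static_assert(
    std::is_trivially_copyable_v<ExampleIntersectionFunctor>,
    "free-function kernel arguments must be trivially copyable");
```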
3 changes: 3 additions & 0 deletions src/ATen/native/sparse/xpu/sycl/SparseTensorKernels.cpp
@@ -211,6 +211,9 @@ struct KernelLauncher {

template <typename index_t, typename hash_coeffs_t>
struct FlattenIndicesFunctor {
// For std::is_trivially_copyable
FlattenIndicesFunctor& operator=(const FlattenIndicesFunctor&) = delete;

int64_t operator()(int64_t nnz_idx) const {
const auto* ptr_indices_dim = ptr_indices_ + nnz_idx * indices_nnz_stride_;
auto hash = static_cast<int64_t>(0);
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/ActivationHardshrinkKernels.cpp
@@ -11,10 +11,10 @@ struct HardshrinkFunctor {
return (a >= -lambd_ && a <= lambd_) ? scalar_t(0) : a;
}

HardshrinkFunctor(const scalar_t lambd) : lambd_(lambd) {}
HardshrinkFunctor(scalar_t lambd) : lambd_(lambd) {}

private:
const scalar_t lambd_;
scalar_t lambd_;
};

void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
@@ -301,8 +301,8 @@ struct AdaptiveAvgPool2dBwdSLMChannelsLastKernelFunctor
index_t istartW = item.get_local_id(1) + item.get_group(1) * iW;
index_t iendW = std::min(istartW + iW, isizeW_);

// Stride for threads, each subgroup can reuse L1 as they go. So theoretically
// better chance to survive cache eviction.
// Stride for threads, each subgroup can reuse L1 as they go. So
// theoretically better chance to survive cache eviction.
for (index_t ih = istartH; ih < iendH; ih += item.get_local_range(0)) {
index_t ostartH = START_IND_INT(ih, isizeH_, osizeH_);
index_t oendH = END_IND_INT(ih, isizeH_, osizeH_);
103 changes: 51 additions & 52 deletions src/ATen/native/xpu/sycl/Loops.h
@@ -89,49 +89,40 @@ struct UnrolledElementwiseKernel {
};

template <int vec_size, typename func_t, typename array_t, typename in_calc_t>
struct VectorizedElementwiseKernel {
void operator()(sycl::nd_item<1> item) const {
int grpsz = item.get_local_range(0);
int grpid = item.get_group(0);
int lid = item.get_local_id(0);
int group_work_size = vec_size * grpsz;
int remaining = numel_ - grpid * group_work_size;

// ic_
if (remaining < group_work_size) {
auto oc = TrivialOffsetCalculator<1>();
auto l = at::native::memory::LoadWithoutCast();
auto s = at::native::memory::StoreWithoutCast();
auto policy = at::native::memory::policies::unroll<
vec_size,
array_t,
decltype(ic_),
decltype(oc),
at::native::memory::LoadWithoutCast,
at::native::memory::StoreWithoutCast>(
data_, remaining, ic_, oc, l, s, lid, grpid, grpsz);
elementwise_kernel_helper<vec_size>(f_, policy);
} else {
auto policy = at::native::memory::policies::
vectorized<vec_size, array_t, in_calc_t>(
data_, ic_, lid, grpid, grpsz);
elementwise_kernel_helper<vec_size>(f_, policy);
}
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void vectorized_elementwise_kernel(
int numel,
const func_t f,
array_t data,
in_calc_t ic) {
auto item = syclext::this_work_item::get_nd_item<1>();
int grpsz = item.get_local_range(0);
int grpid = item.get_group(0);
int lid = item.get_local_id(0);
int group_work_size = vec_size * grpsz;
int remaining = numel - grpid * group_work_size;

// ic_
if (remaining < group_work_size) {
auto oc = TrivialOffsetCalculator<1>();
auto l = at::native::memory::LoadWithoutCast();
auto s = at::native::memory::StoreWithoutCast();
auto policy = at::native::memory::policies::unroll<
vec_size,
array_t,
decltype(ic),
decltype(oc),
at::native::memory::LoadWithoutCast,
at::native::memory::StoreWithoutCast>(
data, remaining, ic, oc, l, s, lid, grpid, grpsz);
elementwise_kernel_helper<vec_size>(f, policy);
} else {
auto policy =
at::native::memory::policies::vectorized<vec_size, array_t, in_calc_t>(
data, ic, lid, grpid, grpsz);
elementwise_kernel_helper<vec_size>(f, policy);
}

VectorizedElementwiseKernel(
int numel,
const func_t f,
array_t data,
in_calc_t ic)
: numel_(numel), f_(f), data_(data), ic_(ic) {}

private:
int numel_;
const func_t f_;
array_t data_;
in_calc_t ic_;
};
}

template <
int num_outputs,
@@ -391,16 +382,24 @@ static inline void launch_vectorized_kernel(
using traits = function_traits<func_t>;
auto wg_sz = syclMaxWorkItemsPerSubSlice();

#define VEC_KER(vec_size) \
{ \
TORCH_CHECK(max_scalar_bytes* vec_size <= 16); \
if constexpr (max_scalar_bytes * vec_size <= 16) { \
auto ker = \
VectorizedElementwiseKernel<vec_size, func_t, array_t, in_calc_t>( \
N, f, data, input_calc); \
int num_wg = ceil_div<int>(N, wg_sz * vec_size); \
sycl_kernel_submit(wg_sz* num_wg, wg_sz, getCurrentSYCLQueue(), ker); \
} \
#define VEC_KER(vec_size) \
{ \
TORCH_CHECK(max_scalar_bytes* vec_size <= 16); \
if constexpr (max_scalar_bytes * vec_size <= 16) { \
int num_wg = ceil_div<int>(N, wg_sz * vec_size); \
sycl_kernel_submit<vectorized_elementwise_kernel< \
vec_size, \
func_t, \
array_t, \
in_calc_t>>( \
wg_sz * num_wg, \
wg_sz, \
getCurrentSYCLQueue(), \
N, \
f, \
data, \
input_calc); \
} \
}

switch (vec_size) {
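For context, the hunk above replaces the `VectorizedElementwiseKernel` functor with a free-function kernel: a namespace-scope function tagged with `SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))` that queries its `nd_item` via `syclext::this_work_item::get_nd_item<1>()` and is launched by looking the function up in an executable kernel bundle and supplying its arguments through `set_args`. Below is a minimal, self-contained sketch of that pattern, assuming a DPC++ compiler with the oneAPI free-function-kernels extension; the kernel name, sizes, and data are illustrative and not from this PR.

```cpp
#include <sycl/sycl.hpp>

namespace syclext = sycl::ext::oneapi;
namespace syclexp = sycl::ext::oneapi::experimental;

// A free-function kernel: the property marks it as a 1-D nd-range kernel,
// and the work item is queried inside the body instead of being passed in.
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void scale_kernel(float* data, float factor) {
  auto item = syclext::this_work_item::get_nd_item<1>();
  size_t i = item.get_global_id(0);
  data[i] *= factor;
}

int main() {
  sycl::queue q;
  constexpr size_t n = 1024;
  float* data = sycl::malloc_shared<float>(n, q);
  for (size_t i = 0; i < n; ++i) data[i] = static_cast<float>(i);

  // Launch: fetch the kernel from an executable bundle, set its arguments,
  // and run it over an nd_range. This mirrors the new sycl_kernel_submit
  // overload added to SYCLHelpers.h in this PR.
  auto bundle =
      syclexp::get_kernel_bundle<scale_kernel, sycl::bundle_state::executable>(
          q.get_context());
  sycl::kernel k = bundle.ext_oneapi_get_kernel<scale_kernel>();
  q.submit([&](sycl::handler& cgh) {
     cgh.set_args(data, 2.0f);
     cgh.parallel_for(
         sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(128)), k);
   }).wait();

  sycl::free(data, q);
  return 0;
}
```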
10 changes: 9 additions & 1 deletion src/ATen/native/xpu/sycl/Reduce.h
@@ -1124,6 +1124,13 @@ int get_output_vec_size(at::TensorIterator& iter) {
return std::min(vec_size, vt1);
}

template <
typename scalar_t,
typename out_scalar_t,
int vt0,
int vt1,
typename ops_t,
typename ident_t>
struct ReduceKernelEmptyFunctor {
char operator()() const {
return 0;
@@ -1374,7 +1381,8 @@ inline void gpu_reduce_kernel(
config.semaphore_size(), at::TensorOptions().dtype(kChar).device(kXPU));
at::detail::Array<char*, 1> data;
data[0] = (char*)semaphores.data_ptr();
ReduceKernelEmptyFunctor fn;
ReduceKernelEmptyFunctor<scalar_t, out_scalar_t, vt0, vt1, ops_t, ident_t>
fn;
int vec_size = at::native::memory::can_vectorize_up_to<decltype(fn)>(data);
auto ic = TrivialOffsetCalculator<traits::arity>();
launch_vectorized_kernel(config.semaphore_size(), fn, data, ic, vec_size);
29 changes: 10 additions & 19 deletions src/ATen/native/xpu/sycl/SortingKernels.h
@@ -316,13 +316,6 @@ void segmented_radix_sort_pairs_downsweep_kernel(

// ======================= large sort =======================

template <typename scalar_t>
struct ABBufferCopyFunctor {
scalar_t operator()(scalar_t x) const {
return x;
}
};

template <
typename key_t,
typename value_t,
@@ -416,20 +409,18 @@ void segmented_radix_sort_pairs_kernel(
auto input_calc = TrivialOffsetCalculator<2>();
at::detail::Array<char*, 2> data;
if (keys_out) {
data[0] = (char*)keys_out;
data[1] = (char*)keys_temp;
auto fn = ABBufferCopyFunctor<key_t>();
auto vec_size = memory::can_vectorize_up_to<decltype(fn)>(data);
launch_vectorized_kernel(
num_segments * num_elements, fn, data, input_calc, vec_size);
auto q = at::xpu::getCurrentSYCLQueue();
q.memcpy(
(void*)keys_out,
(void*)keys_temp,
sizeof(key_t) * num_segments * num_elements);
Review comment from a contributor: why do we replace the EU copy with memcpy here? Better performance?

}
if (values_out) {
data[0] = (char*)values_out;
data[1] = (char*)values_temp;
auto fn = ABBufferCopyFunctor<value_t>();
auto vec_size = memory::can_vectorize_up_to<decltype(fn)>(data);
launch_vectorized_kernel(
num_segments * num_elements, fn, data, input_calc, vec_size);
auto q = at::xpu::getCurrentSYCLQueue();
q.memcpy(
(void*)values_out,
(void*)values_temp,
sizeof(value_t) * num_segments * num_elements);
}
}
}
41 changes: 0 additions & 41 deletions src/ATen/native/xpu/sycl/TensorCompare.cpp
@@ -35,47 +35,6 @@ struct ClampFunctor {
}
};

template <typename scalar_t>
struct ClampScalarFunctor {
using opmath_t = at::opmath_type<scalar_t>;
scalar_t operator()(scalar_t v) const {
if (_isnan(static_cast<opmath_t>(v))) {
return v;
} else if (minmax_ == at::native::detail::ClampLimits::Min) {
return std::max(static_cast<opmath_t>(v), lim0_val_);
} else if (minmax_ == at::native::detail::ClampLimits::Max) {
return std::min(static_cast<opmath_t>(v), lim0_val_);
} else {
return std::min(std::max(static_cast<opmath_t>(v), lim0_val_), lim1_val_);
}
}
ClampScalarFunctor(
opmath_t lim0_val,
opmath_t lim1_val,
at::native::detail::ClampLimits minmax)
: lim0_val_(lim0_val), lim1_val_(lim1_val), minmax_(minmax) {}

private:
opmath_t lim0_val_;
opmath_t lim1_val_;
at::native::detail::ClampLimits minmax_;
};

void inline launch_clamp_scalar(
TensorIteratorBase& iter,
Scalar lim0,
Scalar lim1,
at::native::detail::ClampLimits minmax) {
AT_DISPATCH_ALL_TYPES_AND2(
kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_xpu", [&] {
using opmath_t = at::opmath_type<scalar_t>;
auto lim0_val = lim0.to<opmath_t>();
auto lim1_val = lim1.to<opmath_t>();
gpu_kernel(
iter, ClampScalarFunctor<scalar_t>(lim0_val, lim1_val, minmax));
});
}

} // namespace xpu
} // namespace native
} // namespace at
25 changes: 25 additions & 0 deletions src/comm/SYCLHelpers.h
@@ -3,6 +3,9 @@
#include <comm/Scalar.h>
#include <sycl/sycl.hpp>

namespace syclext = sycl::ext::oneapi;
namespace syclexp = sycl::ext::oneapi::experimental;

// sycl access address space
static constexpr auto sycl_priv_space =
sycl::access::address_space::private_space;
@@ -139,6 +142,28 @@ sycl_kernel_submit(
q.submit(cgf);
}

// For SYCL free function
template <auto* kptr, typename... Kargs>
static inline void sycl_kernel_submit(
int64_t global_range,
int64_t local_range,
::sycl::queue q,
Kargs... args) {
sycl::context ctxt = q.get_context();
auto exe_bndl =
syclexp::get_kernel_bundle<kptr, sycl::bundle_state::executable>(ctxt);
sycl::kernel ker = exe_bndl.template ext_oneapi_get_kernel<kptr>();
auto args_tuple = std::make_tuple(std::forward<Kargs>(args)...);
auto cgf = [&](::sycl::handler& cgh) {
std::apply([&](auto&&... args_) { cgh.set_args(args_...); }, args_tuple);
cgh.parallel_for(
::sycl::nd_range<1>(
::sycl::range<1>(global_range), ::sycl::range<1>(local_range)),
ker);
};
q.submit(cgf);
}

#define SYCL_KERNEL_STRING(var, str) \
static const __attribute__((opencl_constant)) char var[] = str;
#define SYCL_KERNEL_PRINTF sycl::ext::oneapi::experimental::printf
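
A hedged usage sketch of the new overload: the kernel function pointer is supplied as a template argument and the kernel arguments follow the queue, matching the way the `VEC_KER` macro in Loops.h now calls it. The `fill_kernel` and `launch_fill` below are hypothetical and only illustrate the call shape.

```cpp
// Hypothetical free-function kernel (not part of this PR).
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void fill_kernel(int numel, float* out, float value) {
  auto item = syclext::this_work_item::get_nd_item<1>();
  int i = static_cast<int>(item.get_global_linear_id());
  if (i < numel) {
    out[i] = value;
  }
}

// Hypothetical call site: the global range is padded up to a multiple of the
// work-group size, and the kernel guards against the tail.
void launch_fill(float* out, int numel, float value, ::sycl::queue q) {
  int64_t wg_sz = 128;
  int64_t num_wg = (numel + wg_sz - 1) / wg_sz;
  sycl_kernel_submit<fill_kernel>(
      num_wg * wg_sz, wg_sz, q, numel, out, value);
}
```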