Skip to content

Implement SYCL free function style for Reduction kernel #1927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: feng/free_func_vector_loops
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ATen/native/xpu/sycl/Loops.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ static inline void launch_vectorized_kernel(
wg_sz * num_wg, \
wg_sz, \
getCurrentSYCLQueue(), \
0, \
N, \
f, \
data, \
Expand Down
96 changes: 46 additions & 50 deletions src/ATen/native/xpu/sycl/Reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ template <class arg_t, class item_t, class CombineFunc, int out_vec_sz = 1>
inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
item_t item,
int wg_size,
sycl_local_ptr<void> shared,
char* shared,
at::detail::Array<arg_t, out_vec_sz> value,
CombineFunc combine) {
using vec_t = at::detail::Array<arg_t, out_vec_sz>;
sycl_local_ptr<vec_t> shared_(shared);
vec_t* shared_ = reinterpret_cast<vec_t*>(shared_);
int l_x = item.get_local_linear_id();
// int dim_x = wg_size;
auto sg = item.get_sub_group();
Expand Down Expand Up @@ -101,11 +101,11 @@ inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
template <class arg_t, class item_t, class CombineFunc, int out_vec_sz = 1>
inline at::detail::Array<arg_t, out_vec_sz> group_x_reduce(
item_t item,
sycl_local_ptr<void> shared,
char* shared,
at::detail::Array<arg_t, out_vec_sz> value,
CombineFunc combine) {
using vec_t = at::detail::Array<arg_t, out_vec_sz>;
sycl_local_ptr<vec_t> shared_(shared);
vec_t* shared_ = reinterpret_cast<vec_t*>(shared_);
int l_x = item.get_local_id(1), l_y = item.get_local_id(0);
int g_x = item.get_local_range(1);
int dim_x = g_x;
Expand Down Expand Up @@ -143,11 +143,11 @@ inline at::detail::Array<arg_t, out_vec_sz> group_x_reduce(
template <class arg_t, class item_t, class CombineFunc, int out_vec_sz = 1>
inline at::detail::Array<arg_t, out_vec_sz> group_y_reduce(
item_t item,
sycl_local_ptr<void> shared,
char* shared,
at::detail::Array<arg_t, out_vec_sz> value,
CombineFunc combine) {
using vec_t = at::detail::Array<arg_t, out_vec_sz>;
sycl_local_ptr<vec_t> shared_(shared);
vec_t* shared_ = reinterpret_cast<vec_t*>(shared_);
int l_x = item.get_local_id(1), l_y = item.get_local_id(0);
int g_x = item.get_local_range(1);
int dim_y = item.get_local_range(0);
Expand Down Expand Up @@ -238,9 +238,9 @@ struct ReduceConfig {
int input_vec_size = 1;
int output_vec_size = 1;

template <typename T, class KernelClass>
template <typename T, auto* K>
void set_group_dimension(int64_t dim0, int64_t dim1) {
auto max_wg_sz = syclMaxWorkGroupSize<KernelClass>();
auto max_wg_sz = syclMaxWorkGroupSize<K>();
// Bypass reduction on SLM by sparing workload to other SGs. As the
// result, reduction of small shape input only requires some shift
// operations in side of SG. It is functional WA. We got case failures on
Expand Down Expand Up @@ -377,26 +377,10 @@ struct ReduceConfig {
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);

template <int output_vec_size, typename R>
class ReduceKernel : public __SYCL_KER_CONFIG_CONVENTION__ {
public:
ReduceKernel(R reduction, sycl::range<1> slm_sz)
: reduction_(reduction), slm_sz_(slm_sz), shared_(), finished_() {}

void operator()(sycl::nd_item<2> pos) const {
reduction_.template run<output_vec_size>(pos, shared_, finished_);
}

void sycl_ker_config_convention(sycl::handler& cgh) {
shared_ = sycl_local_acc_t<char>(slm_sz_, cgh);
finished_ = sycl_local_acc_t<bool>({1}, cgh);
}

private:
R reduction_;
sycl::range<1> slm_sz_;
sycl_local_acc_t<char> shared_; /* group tree reduce */
sycl_local_acc_t<bool> finished_; /* last WG flag to broadcast inner WG */
};
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>))
void reduce_kernel(R reduction) {
reduction.template run<output_vec_size>();
}

template <typename index_t>
static OffsetCalculator<2, index_t> make_output_calculator(
Expand Down Expand Up @@ -528,10 +512,9 @@ struct ReduceOp {
}

template <int output_vec_size>
void run(
sycl::nd_item<2> pos,
sycl_local_ptr<char> shared,
sycl_local_ptr<bool> finished) const {
void run() const {
auto pos = syclext::this_work_item::get_nd_item<2>();
char* shared = (char*)syclexp::get_work_group_scratch_memory();
index_t output_idx = config.output_idx<output_vec_size>(pos);
index_t input_idx = config.input_idx(pos);
auto base_offsets1 = output_calc.get(output_idx)[1];
Expand Down Expand Up @@ -598,7 +581,7 @@ struct ReduceOp {
}

if (config.should_global_reduce()) {
value = global_reduce<output_vec_size>(pos, value, acc, shared, finished);
value = global_reduce<output_vec_size>(pos, value, acc, shared);
} else if (config.should_store(pos, output_idx)) {
if (accumulate) {
#pragma unroll
Expand Down Expand Up @@ -823,18 +806,19 @@ struct ReduceOp {
}

// In/out from slm pointers
void mark_group_finished(sycl::nd_item<2> pos, sycl_local_ptr<bool> finished)
const {
bool mark_group_finished(sycl::nd_item<2> pos) const {
syclexp::work_group_static<bool> finished;
pos.barrier(sycl_local_fence);

if (pos.get_local_linear_id() == 0) {
sycl_atomic_ref_rlx_dev_global_t<int> count(semaphores[pos.get_group(1)]);
int prev_groups_finished = count.fetch_add(
1, sycl_mem_odr_acq_rel
/* , default memory scope is device */);
finished[0] = (prev_groups_finished == (int)(pos.get_group_range(0) - 1));
finished = (prev_groups_finished == (int)(pos.get_group_range(0) - 1));
}
pos.barrier(sycl_local_fence);
return finished;
}

template <int output_vec_size, bool can_acc>
Expand Down Expand Up @@ -922,8 +906,7 @@ struct ReduceOp {
sycl::nd_item<2> pos,
at::detail::Array<arg_t, output_vec_size> value,
at::detail::Array<arg_t, output_vec_size>* acc,
sycl_local_ptr<char> shared_memory,
sycl_local_ptr<bool> is_last_group_done) const {
char* shared_memory) const {
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
Expand All @@ -945,9 +928,9 @@ struct ReduceOp {
reduce_buffer[offset] = value;
}

mark_group_finished(pos, is_last_group_done);
bool is_last_group_done = mark_group_finished(pos);

if (is_last_group_done[0]) {
if (is_last_group_done) {
value = ident;
if (config.should_group_x_reduce()) {
index_t input_offset =
Expand Down Expand Up @@ -1039,21 +1022,34 @@ static void launch_reduce_kernel(
const ReduceConfig& config,
const R& reduction) {
auto& queue = getCurrentSYCLQueue();
sycl::range<1> slm_sz{static_cast<uint32_t>(config.slm_sz())};
int shared_memory = config.slm_sz();
;
switch (config.output_vec_size) {
case 4: {
auto kfn = ReduceKernel<4, R>(reduction, slm_sz);
sycl_kernel_submit(config.global_sz(), config.group_sz(), queue, kfn);
sycl_kernel_submit<reduce_kernel<4, R>>(
config.global_sz(),
config.group_sz(),
queue,
shared_memory,
reduction);
break;
}
case 2: {
auto kfn = ReduceKernel<2, R>(reduction, slm_sz);
sycl_kernel_submit(config.global_sz(), config.group_sz(), queue, kfn);
sycl_kernel_submit<reduce_kernel<2, R>>(
config.global_sz(),
config.group_sz(),
queue,
shared_memory,
reduction);
break;
}
default: {
auto kfn = ReduceKernel<1, R>(reduction, slm_sz);
sycl_kernel_submit(config.global_sz(), config.group_sz(), queue, kfn);
sycl_kernel_submit<reduce_kernel<1, R>>(
config.global_sz(),
config.group_sz(),
queue,
shared_memory,
reduction);
break;
}
}
Expand Down Expand Up @@ -1297,15 +1293,15 @@ inline void gpu_reduce_kernel(
using R = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>;
switch (config.output_vec_size) {
case 4: {
config.set_group_dimension<scalar_t, ReduceKernel<4, R>>(dim0, dim1);
config.set_group_dimension<scalar_t, reduce_kernel<4, R>>(dim0, dim1);
break;
}
case 2: {
config.set_group_dimension<scalar_t, ReduceKernel<2, R>>(dim0, dim1);
config.set_group_dimension<scalar_t, reduce_kernel<2, R>>(dim0, dim1);
break;
}
default: {
config.set_group_dimension<scalar_t, ReduceKernel<1, R>>(dim0, dim1);
config.set_group_dimension<scalar_t, reduce_kernel<1, R>>(dim0, dim1);
break;
}
}
Expand Down
24 changes: 16 additions & 8 deletions src/ATen/native/xpu/sycl/WeightNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ struct WeightNormReduceKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
value *= value;
}

char* shared_ptr = reinterpret_cast<char*>(
shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
if (cfg_.problem_along_x_) {
value = group_x_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
} else {
value = group_y_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
}

if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) {
Expand Down Expand Up @@ -289,12 +291,14 @@ struct WeightNormKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
}
}

char* shared_ptr = reinterpret_cast<char*>(
shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
if (cfg_.problem_along_x_) {
value = group_x_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
} else {
value = group_y_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
}

int n_slid = (int)id.glb_batch % batch_wg_range_;
Expand Down Expand Up @@ -500,12 +504,14 @@ struct WeightNormBackwardReduceKernelFunctor
}
}

char* shared_ptr = reinterpret_cast<char*>(
shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
if (cfg_.problem_along_x_) {
value = group_x_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
} else {
value = group_y_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
}

if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) {
Expand Down Expand Up @@ -813,12 +819,14 @@ struct WeightNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
}
}

char* shared_ptr = reinterpret_cast<char*>(
shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
if (cfg_.problem_along_x_) {
value = group_x_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
} else {
value = group_y_reduce(
item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
}

int n_slid = (int)id.glb_batch % batch_wg_range_;
Expand Down
18 changes: 18 additions & 0 deletions src/comm/DeviceProperties.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
#include <ATen/xpu/XPUContext.h>

#include <comm/Runtime.h>
#include <sycl/sycl.hpp>
#include <iostream>

namespace syclext = sycl::ext::oneapi;
namespace syclexp = sycl::ext::oneapi::experimental;

namespace xpu {
namespace sycl {

Expand Down Expand Up @@ -35,6 +39,20 @@ static int64_t syclMaxWorkGroupSize(
return syclMaxWorkGroupSize<KernelClass>(dev_id);
}

// For SYCL free function kernels: query the maximum work-group size the
// kernel identified by kptr can be launched with on the device backing
// dev_id (defaults to the device of the current queue).
template <auto* kptr>
static int64_t syclMaxWorkGroupSize(
    at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
  auto queue = c10::xpu::getCurrentXPUStream(dev_id).queue();
  // Resolve the free-function kernel through its executable kernel bundle.
  auto bundle =
      ::syclexp::get_kernel_bundle<kptr, ::sycl::bundle_state::executable>(
          queue.get_context());
  ::sycl::kernel kernel = bundle.template ext_oneapi_get_kernel<kptr>();
  return kernel
      .get_info<::sycl::info::kernel_device_specific::work_group_size>(
          queue.get_device());
}

static inline int64_t syclDeviceMaxWorkGroupSize(
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
Expand Down
40 changes: 37 additions & 3 deletions src/comm/SYCLHelpers.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <comm/Scalar.h>
#include <sycl/ext/oneapi/work_group_static.hpp> // Remove it once the header is exposed by sycl.hpp
#include <sycl/sycl.hpp>

namespace syclext = sycl::ext::oneapi;
Expand Down Expand Up @@ -148,14 +149,47 @@ static inline void sycl_kernel_submit(
int64_t global_range,
int64_t local_range,
::sycl::queue q,
int slm_sz,
Kargs... args) {
sycl::context ctxt = q.get_context();
auto exe_bndl =
syclexp::get_kernel_bundle<kptr, sycl::bundle_state::executable>(ctxt);
sycl::kernel ker = exe_bndl.template ext_oneapi_get_kernel<kptr>();
syclexp::launch_config cfg{::sycl::nd_range<1>(
::sycl::range<1>(global_range), ::sycl::range<1>(local_range))};
syclexp::nd_launch(q, cfg, ker, args...);
if (slm_sz != 0) {
syclexp::launch_config cfg{
::sycl::nd_range<1>(
::sycl::range<1>(global_range), ::sycl::range<1>(local_range)),
syclexp::properties{syclexp::work_group_scratch_size(slm_sz)}};
syclexp::nd_launch(q, cfg, ker, args...);
} else {
syclexp::launch_config cfg{::sycl::nd_range<1>(
::sycl::range<1>(global_range), ::sycl::range<1>(local_range))};
syclexp::nd_launch(q, cfg, ker, args...);
}
}

// Submit the SYCL free-function kernel kptr over a dim-dimensional
// nd-range. When slm_sz is non-zero, that many bytes of work-group
// scratch memory are requested via the launch properties; otherwise the
// kernel is launched with a plain launch_config.
template <auto* kptr, int dim, typename... Kargs>
static inline void sycl_kernel_submit(
    ::sycl::range<dim> global_range,
    ::sycl::range<dim> local_range,
    ::sycl::queue q,
    int slm_sz,
    Kargs... args) {
  auto context = q.get_context();
  auto bundle =
      syclexp::get_kernel_bundle<kptr, sycl::bundle_state::executable>(context);
  sycl::kernel kernel = bundle.template ext_oneapi_get_kernel<kptr>();
  ::sycl::nd_range<dim> execution_range(
      ::sycl::range<dim>(global_range), ::sycl::range<dim>(local_range));
  if (slm_sz == 0) {
    syclexp::launch_config cfg{execution_range};
    syclexp::nd_launch(q, cfg, kernel, args...);
  } else {
    // Attach the scratch-memory size as a launch property so the kernel
    // can retrieve it via get_work_group_scratch_memory().
    syclexp::launch_config cfg{
        execution_range,
        syclexp::properties{syclexp::work_group_scratch_size(slm_sz)}};
    syclexp::nd_launch(q, cfg, kernel, args...);
  }
}

#define SYCL_KERNEL_STRING(var, str) \
Expand Down