diff --git a/src/ATen/native/xpu/sycl/Loops.h b/src/ATen/native/xpu/sycl/Loops.h
index 728835b29..55a3d797c 100644
--- a/src/ATen/native/xpu/sycl/Loops.h
+++ b/src/ATen/native/xpu/sycl/Loops.h
@@ -395,6 +395,7 @@ static inline void launch_vectorized_kernel(
       wg_sz * num_wg,        \
       wg_sz,                 \
       getCurrentSYCLQueue(), \
+      0,                     \
       N,                     \
       f,                     \
       data,                  \
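Note: the elementwise launch macro above only gains a literal `0` because `sycl_kernel_submit` now takes a scratch-size argument ahead of the kernel arguments (see the `src/comm/SYCLHelpers.h` hunks below), and elementwise kernels reserve no work-group local memory. A sketch of the call shape the macro expands to; the kernel name `vectorized_elementwise_kernel` is a hypothetical stand-in for whatever free-function kernel Loops.h actually instantiates:

  // Hypothetical expansion of the launch macro after this change.
  sycl_kernel_submit<vectorized_elementwise_kernel<vec_size, func_t>>(
      wg_sz * num_wg,        // 1-D global range
      wg_sz,                 // 1-D local range
      getCurrentSYCLQueue(),
      /*slm_sz=*/0,          // elementwise kernels need no SLM scratch
      N, f, data);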
diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h
index 05c804202..6900e73af 100644
--- a/src/ATen/native/xpu/sycl/Reduce.h
+++ b/src/ATen/native/xpu/sycl/Reduce.h
@@ -32,11 +32,11 @@ template <typename T, int dim, typename item_t, typename CombineFunc>
 inline at::detail::Array<T, dim> group_reduce(
     item_t item,
     int wg_size,
-    sycl_local_ptr<void> shared,
+    char* shared,
     at::detail::Array<T, dim> value,
     CombineFunc combine) {
   using vec_t = at::detail::Array<T, dim>;
-  sycl_local_ptr<vec_t> shared_(shared);
+  vec_t* shared_ = reinterpret_cast<vec_t*>(shared);
   int l_x = item.get_local_linear_id();
   // int dim_x = wg_size;
   auto sg = item.get_sub_group();
@@ -101,11 +101,11 @@ inline at::detail::Array<T, dim> group_reduce(
 template <typename T, int dim, typename item_t, typename CombineFunc>
 inline at::detail::Array<T, dim> group_x_reduce(
     item_t item,
-    sycl_local_ptr<void> shared,
+    char* shared,
     at::detail::Array<T, dim> value,
     CombineFunc combine) {
   using vec_t = at::detail::Array<T, dim>;
-  sycl_local_ptr<vec_t> shared_(shared);
+  vec_t* shared_ = reinterpret_cast<vec_t*>(shared);
   int l_x = item.get_local_id(1), l_y = item.get_local_id(0);
   int g_x = item.get_local_range(1);
   int dim_x = g_x;
@@ -143,11 +143,11 @@ inline at::detail::Array<T, dim> group_x_reduce(
 template <typename T, int dim, typename item_t, typename CombineFunc>
 inline at::detail::Array<T, dim> group_y_reduce(
     item_t item,
-    sycl_local_ptr<void> shared,
+    char* shared,
     at::detail::Array<T, dim> value,
     CombineFunc combine) {
   using vec_t = at::detail::Array<T, dim>;
-  sycl_local_ptr<vec_t> shared_(shared);
+  vec_t* shared_ = reinterpret_cast<vec_t*>(shared);
   int l_x = item.get_local_id(1), l_y = item.get_local_id(0);
   int g_x = item.get_local_range(1);
   int dim_y = item.get_local_range(0);
@@ -238,9 +238,9 @@ struct ReduceConfig {
   int input_vec_size = 1;
   int output_vec_size = 1;
 
-  template <typename T, typename KernelClass>
+  template <typename T, auto* KernelFunc>
   void set_group_dimension(int64_t dim0, int64_t dim1) {
-    auto max_wg_sz = syclMaxWorkGroupSize<KernelClass>();
+    auto max_wg_sz = syclMaxWorkGroupSize<KernelFunc>();
     // Bypass reduction on SLM by sparing workload to other SGs. As the
     // result, reduction of small shape input only requires some shift
     // operations in side of SG. It is functional WA. We got case failures on
@@ -377,26 +377,10 @@ struct ReduceConfig {
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
 
 template <int output_vec_size, typename R>
-class ReduceKernel : public __SYCL_KER_CONFIG_CONVENTION__ {
- public:
-  ReduceKernel(R reduction, sycl::range<1> slm_sz)
-      : reduction_(reduction), slm_sz_(slm_sz), shared_(), finished_() {}
-
-  void operator()(sycl::nd_item<2> pos) const {
-    reduction_.template run<output_vec_size>(pos, shared_, finished_);
-  }
-
-  void sycl_ker_config_convention(sycl::handler& cgh) {
-    shared_ = sycl_local_acc_t<char>(slm_sz_, cgh);
-    finished_ = sycl_local_acc_t<bool>({1}, cgh);
-  }
-
- private:
-  R reduction_;
-  sycl::range<1> slm_sz_;
-  sycl_local_acc_t<char> shared_; /* group tree reduce */
-  sycl_local_acc_t<bool> finished_; /* last WG flag to broadcast inner WG */
-};
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>))
+void reduce_kernel(R reduction) {
+  reduction.template run<output_vec_size>();
+}
 
 template <typename index_t>
 static OffsetCalculator<2, index_t> make_output_calculator(
@@ -528,10 +512,9 @@ struct ReduceOp {
   }
 
   template <int output_vec_size>
-  void run(
-      sycl::nd_item<2> pos,
-      sycl_local_ptr<char> shared,
-      sycl_local_ptr<bool> finished) const {
+  void run() const {
+    auto pos = syclext::this_work_item::get_nd_item<2>();
+    char* shared = (char*)syclexp::get_work_group_scratch_memory();
     index_t output_idx = config.output_idx<output_vec_size>(pos);
     index_t input_idx = config.input_idx(pos);
     auto base_offsets1 = output_calc.get(output_idx)[1];
@@ -598,7 +581,7 @@ struct ReduceOp {
     }
 
     if (config.should_global_reduce()) {
-      value = global_reduce<output_vec_size>(pos, value, acc, shared, finished);
+      value = global_reduce<output_vec_size>(pos, value, acc, shared);
     } else if (config.should_store(pos, output_idx)) {
       if (accumulate) {
 #pragma unroll
@@ -823,8 +806,8 @@ struct ReduceOp {
   }
 
   // In/out from slm pointers
-  void mark_group_finished(sycl::nd_item<2> pos, sycl_local_ptr<bool> finished)
-      const {
+  bool mark_group_finished(sycl::nd_item<2> pos) const {
+    syclexp::work_group_static<bool> finished;
     pos.barrier(sycl_local_fence);
 
     if (pos.get_local_linear_id() == 0) {
@@ -832,9 +815,10 @@ struct ReduceOp {
       int prev_groups_finished = count.fetch_add(
           1, sycl_mem_odr_acq_rel /* , default memory scope is device */);
-      finished[0] = (prev_groups_finished == (int)(pos.get_group_range(0) - 1));
+      finished = (prev_groups_finished == (int)(pos.get_group_range(0) - 1));
     }
     pos.barrier(sycl_local_fence);
+    return finished;
   }
 
   template <int output_vec_size>
@@ -922,8 +906,7 @@ struct ReduceOp {
       sycl::nd_item<2> pos,
       at::detail::Array<arg_t, output_vec_size> value,
       at::detail::Array<arg_t, output_vec_size>* acc,
-      sycl_local_ptr<char> shared_memory,
-      sycl_local_ptr<bool> is_last_group_done) const {
+      char* shared_memory) const {
     using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
     using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
     using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
@@ -945,9 +928,9 @@ struct ReduceOp {
       reduce_buffer[offset] = value;
     }
 
-    mark_group_finished(pos, is_last_group_done);
+    bool is_last_group_done = mark_group_finished(pos);
 
-    if (is_last_group_done[0]) {
+    if (is_last_group_done) {
       value = ident;
       if (config.should_group_x_reduce()) {
         index_t input_offset =
@@ -1039,21 +1022,34 @@ static void launch_reduce_kernel(
     const ReduceConfig& config,
     const R& reduction) {
   auto& queue = getCurrentSYCLQueue();
-  sycl::range<1> slm_sz{static_cast<size_t>(config.slm_sz())};
+  int shared_memory = config.slm_sz();
 
   switch (config.output_vec_size) {
     case 4: {
-      auto kfn = ReduceKernel<4, R>(reduction, slm_sz);
-      sycl_kernel_submit(config.global_sz(), config.group_sz(), queue, kfn);
+      sycl_kernel_submit<reduce_kernel<4, R>>(
+          config.global_sz(),
+          config.group_sz(),
+          queue,
+          shared_memory,
+          reduction);
       break;
     }
     case 2: {
-      auto kfn = ReduceKernel<2, R>(reduction, slm_sz);
-      sycl_kernel_submit(config.global_sz(), config.group_sz(), queue, kfn);
+      sycl_kernel_submit<reduce_kernel<2, R>>(
+          config.global_sz(),
+          config.group_sz(),
+          queue,
+          shared_memory,
+          reduction);
       break;
     }
     default: {
-      auto kfn = ReduceKernel<1, R>(reduction, slm_sz);
-      sycl_kernel_submit(config.global_sz(), config.group_sz(), queue, kfn);
+      sycl_kernel_submit<reduce_kernel<1, R>>(
+          config.global_sz(),
+          config.group_sz(),
+          queue,
+          shared_memory,
+          reduction);
       break;
     }
   }
@@ -1297,15 +1293,15 @@ inline void gpu_reduce_kernel(
   using R = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>;
   switch (config.output_vec_size) {
     case 4: {
-      config.set_group_dimension<scalar_t, ReduceKernel<4, R>>(dim0, dim1);
+      config.set_group_dimension<scalar_t, reduce_kernel<4, R>>(dim0, dim1);
      break;
    }
    case 2: {
-      config.set_group_dimension<scalar_t, ReduceKernel<2, R>>(dim0, dim1);
+      config.set_group_dimension<scalar_t, reduce_kernel<2, R>>(dim0, dim1);
      break;
    }
    default: {
-      config.set_group_dimension<scalar_t, ReduceKernel<1, R>>(dim0, dim1);
+      config.set_group_dimension<scalar_t, reduce_kernel<1, R>>(dim0, dim1);
      break;
    }
  }
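Note: the Reduce.h rewrite above swaps the `ReduceKernel` functor (whose SLM lived in `sycl_local_acc_t` members wired up through `sycl_ker_config_convention`) for a free-function kernel that fetches everything inside the body: the `nd_item` via `this_work_item::get_nd_item`, the scratch pointer via `get_work_group_scratch_memory`, and the last-group flag via `work_group_static`. A minimal self-contained sketch of the same pattern, assuming a DPC++ compiler with these experimental extensions; `Sketch` and `sketch_kernel` are hypothetical stand-ins for `ReduceOp` and `reduce_kernel`, not the patch's real types:

  #include <sycl/sycl.hpp>
  #include <sycl/ext/oneapi/free_function_queries.hpp>

  namespace syclext = sycl::ext::oneapi;
  namespace syclexp = sycl::ext::oneapi::experimental;

  // Reduction payload: like ReduceOp::run() above, it obtains the nd_item and
  // the SLM scratch pointer inside the kernel body instead of taking them as
  // kernel parameters.
  struct Sketch {
    const float* in;
    float* out;

    void run() const {
      auto pos = syclext::this_work_item::get_nd_item<1>();
      char* shared = (char*)syclexp::get_work_group_scratch_memory();
      float* partial = reinterpret_cast<float*>(shared);

      size_t lid = pos.get_local_linear_id();
      partial[lid] = in[pos.get_global_linear_id()];
      sycl::group_barrier(pos.get_group());

      if (lid == 0) { // group leader folds the per-group partials from SLM
        float acc = 0.f;
        for (size_t i = 0; i < pos.get_local_range(0); ++i)
          acc += partial[i];
        out[pos.get_group(0)] = acc;
      }
    }
  };

  // The kernel itself is just a templated trampoline, as reduce_kernel is.
  template <typename R>
  SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
  void sketch_kernel(R reduction) {
    reduction.run();
  }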
diff --git a/src/ATen/native/xpu/sycl/WeightNormKernels.cpp b/src/ATen/native/xpu/sycl/WeightNormKernels.cpp
index ac67d5d34..2de1d32aa 100644
--- a/src/ATen/native/xpu/sycl/WeightNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/WeightNormKernels.cpp
@@ -53,12 +53,14 @@ struct WeightNormReduceKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       value *= value;
     }
 
+    char* shared_ptr = reinterpret_cast<char*>(
+        shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
     if (cfg_.problem_along_x_) {
       value = group_x_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     } else {
       value = group_y_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     }
 
     if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) {
@@ -289,12 +291,14 @@ struct WeightNormKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       }
     }
 
+    char* shared_ptr = reinterpret_cast<char*>(
+        shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
     if (cfg_.problem_along_x_) {
       value = group_x_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     } else {
       value = group_y_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     }
 
     int n_slid = (int)id.glb_batch % batch_wg_range_;
@@ -500,12 +504,14 @@ struct WeightNormBackwardReduceKernelFunctor
       }
     }
 
+    char* shared_ptr = reinterpret_cast<char*>(
+        shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
     if (cfg_.problem_along_x_) {
       value = group_x_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     } else {
       value = group_y_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     }
 
     if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) {
@@ -813,12 +819,14 @@ struct WeightNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       }
     }
 
+    char* shared_ptr = reinterpret_cast<char*>(
+        shared_.template get_multi_ptr<sycl::access::decorated::no>().get());
     if (cfg_.problem_along_x_) {
       value = group_x_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     } else {
       value = group_y_reduce(
-          item, shared_, vec_t(value), ReduceAdd<accscalar_t>())[0];
+          item, shared_ptr, vec_t(value), ReduceAdd<accscalar_t>())[0];
     }
 
     int n_slid = (int)id.glb_batch % batch_wg_range_;
diff --git a/src/comm/DeviceProperties.h b/src/comm/DeviceProperties.h
index ee0d285ea..1620ea1b7 100644
--- a/src/comm/DeviceProperties.h
+++ b/src/comm/DeviceProperties.h
@@ -3,8 +3,12 @@
 #include <ATen/xpu/XPUContext.h>
 #include <c10/xpu/XPUStream.h>
+#include <sycl/kernel_bundle.hpp>
 #include <sycl/sycl.hpp>
 
+namespace syclext = sycl::ext::oneapi;
+namespace syclexp = sycl::ext::oneapi::experimental;
+
 namespace xpu {
 namespace sycl {
@@ -35,6 +39,20 @@
   return syclMaxWorkGroupSize(dev_id);
 }
 
+// For SYCL free function
+template <auto* KernelFunc>
+static int64_t syclMaxWorkGroupSize(
+    at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
+  auto q = c10::xpu::getCurrentXPUStream(dev_id).queue();
+  auto ctxt = q.get_context();
+  auto dev = q.get_device();
+  auto exe_bndl =
+      ::syclexp::get_kernel_bundle<KernelFunc, ::sycl::bundle_state::executable>(
+          ctxt);
+  ::sycl::kernel k = exe_bndl.template ext_oneapi_get_kernel<KernelFunc>();
+  return k.get_info<::sycl::info::kernel_device_specific::work_group_size>(dev);
+}
+
 static inline int64_t syclDeviceMaxWorkGroupSize(
     at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
   auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
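Note: once a kernel is named by a function pointer instead of a functor type, its device-specific work-group-size cap has to be read off the compiled `sycl::kernel` object, which is what the new `syclMaxWorkGroupSize` overload above does via the free-function kernel-bundle query. Hypothetical usage, continuing the `sketch_kernel<Sketch>` example from the earlier note:

  // Clamp the launch's local range to what the compiled kernel supports on
  // this device (register pressure can push it below the device-wide max).
  int64_t max_wg = syclMaxWorkGroupSize<sketch_kernel<Sketch>>();
  int64_t wg_sz = std::min<int64_t>(max_wg, 512);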
diff --git a/src/comm/SYCLHelpers.h b/src/comm/SYCLHelpers.h
index 34d33698c..a05a31da9 100644
--- a/src/comm/SYCLHelpers.h
+++ b/src/comm/SYCLHelpers.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <sycl/ext/oneapi/experimental/enqueue_functions.hpp>
+#include <sycl/ext/oneapi/experimental/work_group_scratch_memory.hpp> // Remove it once the header is exposed by sycl.hpp
 #include <sycl/sycl.hpp>
 
 namespace syclext = sycl::ext::oneapi;
@@ -148,14 +149,47 @@ static inline void sycl_kernel_submit(
     int64_t global_range,
     int64_t local_range,
     ::sycl::queue q,
+    int slm_sz,
     Kargs... args) {
   sycl::context ctxt = q.get_context();
   auto exe_bndl =
       syclexp::get_kernel_bundle<F, ::sycl::bundle_state::executable>(ctxt);
   sycl::kernel ker = exe_bndl.template ext_oneapi_get_kernel<F>();
-  syclexp::launch_config cfg{::sycl::nd_range<1>(
-      ::sycl::range<1>(global_range), ::sycl::range<1>(local_range))};
-  syclexp::nd_launch(q, cfg, ker, args...);
+  if (slm_sz != 0) {
+    syclexp::launch_config cfg{
+        ::sycl::nd_range<1>(
+            ::sycl::range<1>(global_range), ::sycl::range<1>(local_range)),
+        syclexp::properties{syclexp::work_group_scratch_size(slm_sz)}};
+    syclexp::nd_launch(q, cfg, ker, args...);
+  } else {
+    syclexp::launch_config cfg{::sycl::nd_range<1>(
+        ::sycl::range<1>(global_range), ::sycl::range<1>(local_range))};
+    syclexp::nd_launch(q, cfg, ker, args...);
+  }
 }
 
+template <auto* F, int dim, typename... Kargs>
+static inline void sycl_kernel_submit(
+    ::sycl::range<dim> global_range,
+    ::sycl::range<dim> local_range,
+    ::sycl::queue q,
+    int slm_sz,
+    Kargs... args) {
+  sycl::context ctxt = q.get_context();
+  auto exe_bndl =
+      syclexp::get_kernel_bundle<F, ::sycl::bundle_state::executable>(ctxt);
+  sycl::kernel ker = exe_bndl.template ext_oneapi_get_kernel<F>();
+  if (slm_sz != 0) {
+    syclexp::launch_config cfg{
+        ::sycl::nd_range<dim>(
+            ::sycl::range<dim>(global_range), ::sycl::range<dim>(local_range)),
+        syclexp::properties{syclexp::work_group_scratch_size(slm_sz)}};
+    syclexp::nd_launch(q, cfg, ker, args...);
+  } else {
+    syclexp::launch_config cfg{::sycl::nd_range<dim>(
+        ::sycl::range<dim>(global_range), ::sycl::range<dim>(local_range))};
+    syclexp::nd_launch(q, cfg, ker, args...);
+  }
+}
 
 #define SYCL_KERNEL_STRING(var, str) \
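Note: `work_group_scratch_size` is a launch-time property, so the same compiled kernel can be launched with whatever SLM budget the caller computed (`config.slm_sz()` in `launch_reduce_kernel`), and the property is omitted entirely when the request is zero. A host-side sketch using the new `::sycl::range<dim>` overload; `sketch_kernel<Sketch>` and the USM buffers are the hypothetical stand-ins from the earlier notes:

  sycl::queue q = getCurrentSYCLQueue();
  constexpr size_t wg = 128;
  const float* in = sycl::malloc_device<float>(1024, q);
  float* out = sycl::malloc_device<float>(1024 / wg, q);
  // Reserve one float of scratch per work-item; the kernel body reads the
  // same region back through get_work_group_scratch_memory().
  sycl_kernel_submit<sketch_kernel<Sketch>>(
      sycl::range<1>(1024),  // global range
      sycl::range<1>(wg),    // local range
      q,
      wg * sizeof(float),    // slm_sz -> work_group_scratch_size property
      Sketch{in, out});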