diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index 1ba7b75dbcd64..d576801a8c661 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -1573,6 +1573,27 @@ public: } // namespace syclcompat ``` +SYCLcompat provides a wrapper API `max_active_work_groups_per_cu` providing +'work-groups per compute unit' semantics. It is templated on the kernel +functor, and takes a work-group size represented by either `sycl::range` +or `syclcompat::dim3`, the local memory size in bytes, and an optional queue. +The function returns the maximum number of work-groups which can be executed +per compute unit. May return *zero* even when below resource limits (i.e. +returning `0` does not imply the kernel cannot execute). +```cpp +namespace syclcompat{ +template +size_t max_active_work_groups_per_cu( + syclcompat::dim3 wg_dim3, size_t local_mem_size, + sycl::queue queue = syclcompat::get_default_queue()); + +template +size_t max_active_work_groups_per_cu( + sycl::range wg_range, size_t local_mem_size, + sycl::queue queue = syclcompat::get_default_queue()); +} +``` + To assist machine translation, helper aliases are provided for inlining and alignment attributes. The class template declarations `sycl_compat_kernel_name` and `sycl_compat_kernel_scalar` are used to assist automatic generation of diff --git a/sycl/include/syclcompat/util.hpp b/sycl/include/syclcompat/util.hpp index c7a53522c8e51..52df2ded36506 100644 --- a/sycl/include/syclcompat/util.hpp +++ b/sycl/include/syclcompat/util.hpp @@ -39,6 +39,7 @@ #include #include +#include #if defined(__NVPTX__) #include @@ -919,6 +920,46 @@ class group : public group_base { }; } // namespace experimental +// Calculate the number of work-groups per compute unit +// \tparam [in] KernelName SYCL kernel name to calculate for +// \param [in] q SYCL queue used to execute kernel +// \param [in] wg_dim3 dim3 representing work-group shape +// \param [in] local_mem_size Local memory usage per work-group in bytes +// \return size_t representing maximum work-groups per compute unit +template +size_t max_active_work_groups_per_cu( + syclcompat::dim3 wg_dim3, size_t local_mem_size, + sycl::queue queue = syclcompat::get_default_queue()) { + namespace syclex = sycl::ext::oneapi::experimental; + // max_num_work_groups only supports range<3> + auto ctx = queue.get_context(); + auto bundle = sycl::get_kernel_bundle(ctx); + auto kernel = bundle.template get_kernel(); + sycl::range<3> wg_range_3d(wg_dim3); + size_t max_wgs = kernel.template ext_oneapi_get_info< + syclex::info::kernel_queue_specific::max_num_work_groups>(queue, wg_range_3d, + local_mem_size); + size_t max_compute_units = + queue.get_device().get_info(); + // Spec dictates max_compute_units > 0, so no need to catch div 0 + return max_wgs / max_compute_units; +} + +// Calculate the number of work-groups per compute unit +// \tparam [in] KernelName SYCL kernel name to calculate for +// \tparam [in] RangeDim the dimension of the sycl::range +// \param [in] q SYCL queue used to execute kernel +// \param [in] wg_range SYCL work-group range +// \param [in] local_mem_size Local memory usage per work-group in bytes +// \return size_t representing maximum work-groups per compute unit +template +size_t max_active_work_groups_per_cu( + sycl::range wg_range, size_t local_mem_size, + sycl::queue queue = syclcompat::get_default_queue()) { + return max_active_work_groups_per_cu(syclcompat::dim3(wg_range), + local_mem_size, queue); +} + /// If x <= 2, then return a pointer to the default queue; /// otherwise, return x reinterpreted as a queue_ptr. inline queue_ptr int_as_queue_ptr(uintptr_t x) { diff --git a/sycl/test-e2e/syclcompat/util/max_active_work_groups_per_cu.cpp b/sycl/test-e2e/syclcompat/util/max_active_work_groups_per_cu.cpp new file mode 100644 index 0000000000000..78c2a00756332 --- /dev/null +++ b/sycl/test-e2e/syclcompat/util/max_active_work_groups_per_cu.cpp @@ -0,0 +1,133 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCLcompat + * + * max_active_work_groups_per_cu.cpp + * + * Description: + * Test the syclcompat::max_active_work_groups_per_cu API + **************************************************************************/ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +#include "sycl/accessor.hpp" +#include +#include + +template +using sycl_global_accessor = + sycl::accessor; + +using value_type = int; + +template struct MyKernel { + MyKernel(sycl_global_accessor acc) : acc_{acc} {} + void operator()(sycl::nd_item item) const { + auto gid = item.get_global_id(); + acc_[gid] = item.get_global_linear_id(); + } + sycl_global_accessor acc_; + static constexpr bool has_local_mem = false; +}; + +template struct MyLocalMemKernel { + MyLocalMemKernel(sycl_global_accessor acc, + sycl::local_accessor lacc) + : acc_{acc}, lacc_{lacc} {} + void operator()(sycl::nd_item item) const { + auto gid = item.get_global_id(); + acc_[gid] = item.get_global_linear_id(); + auto lid = item.get_local_id(); + lacc_[lid] = item.get_global_linear_id(); + } + sycl_global_accessor acc_; + sycl::local_accessor lacc_; + static constexpr bool has_local_mem = true; +}; + +template