Skip to content

Commit 0a44c80

Browse files
committed
Adopt templatized approach to avoid having to grab kernel_bundle
Also reorder params to match the rest of syclcompat
1 parent a79a9dd commit 0a44c80

File tree

3 files changed

+36
-34
lines changed

3 files changed

+36
-34
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,23 +1574,23 @@ public:
15741574
```
15751575
15761576
SYCLcompat provides a wrapper API `max_active_work_groups_per_cu` providing
1577-
'work-groups per compute unit' semantics. It takes a `sycl::kernel` object, a
1578-
`sycl::queue`, a work-groups size represented by either `sycl::range<Dim>` or
1579-
`syclcompat::dim3`, and the local memory size in bytes. The function returns
1580-
the maximum number of work-groups which can be executed per compute unit. May
1581-
return *zero* even when below resource limits (i.e. returning `0` does not
1582-
imply the kernel cannot execute).
1577+
'work-groups per compute unit' semantics. It is templated on the kernel
1578+
functor, and takes a work-group size represented by either `sycl::range<Dim>`
1579+
or `syclcompat::dim3`, the local memory size in bytes, and an optional `sycl::queue`.
1580+
The function returns the maximum number of work-groups which can be executed
1581+
per compute unit. May return *zero* even when below resource limits (i.e.
1582+
returning `0` does not imply the kernel cannot execute).
15831583
```cpp
15841584
namespace syclcompat{
15851585
template <class KernelName>
1586-
size_t max_active_work_groups_per_cu(KernelName kernel, sycl::queue q,
1587-
syclcompat::dim3 wg_dim3,
1588-
size_t local_mem_size);
1586+
size_t max_active_work_groups_per_cu(
1587+
syclcompat::dim3 wg_dim3, size_t local_mem_size,
1588+
sycl::queue queue = syclcompat::get_default_queue());
15891589
15901590
template <class KernelName, int RangeDim>
1591-
size_t max_active_work_groups_per_cu(KernelName kernel, sycl::queue q,
1592-
sycl::range<RangeDim> wg_range,
1593-
size_t local_mem_size);
1591+
size_t max_active_work_groups_per_cu(
1592+
sycl::range<RangeDim> wg_range, size_t local_mem_size,
1593+
sycl::queue queue = syclcompat::get_default_queue());
15941594
}
15951595
```
15961596

sycl/include/syclcompat/util.hpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -921,39 +921,43 @@ class group : public group_base<dimensions> {
921921
} // namespace experimental
922922

923923
// Calculate the number of work-groups per compute unit
924-
// \param [in] kernel SYCL kernel to calculate for
924+
// \tparam KernelName SYCL kernel name to calculate for
925925
// \param [in] queue SYCL queue used to execute kernel
926926
// \param [in] wg_dim3 dim3 representing work-group shape
927927
// \param [in] local_mem_size Local memory usage per work-group in bytes
928928
// \return size_t representing maximum work-groups per compute unit
929929
template <class KernelName>
930-
size_t max_active_work_groups_per_cu(KernelName kernel, sycl::queue q,
931-
syclcompat::dim3 wg_dim3,
932-
size_t local_mem_size) {
930+
size_t max_active_work_groups_per_cu(
931+
syclcompat::dim3 wg_dim3, size_t local_mem_size,
932+
sycl::queue queue = syclcompat::get_default_queue()) {
933933
namespace syclex = sycl::ext::oneapi::experimental;
934934
// max_num_work_groups only supports range<3>
935+
auto ctx = queue.get_context();
936+
auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
937+
auto kernel = bundle.template get_kernel<KernelName>();
935938
sycl::range<3> wg_range_3d(wg_dim3);
936939
size_t max_wgs = kernel.template ext_oneapi_get_info<
937-
syclex::info::kernel_queue_specific::max_num_work_groups>(q, wg_range_3d,
940+
syclex::info::kernel_queue_specific::max_num_work_groups>(queue, wg_range_3d,
938941
local_mem_size);
939942
size_t max_compute_units =
940-
q.get_device().get_info<sycl::info::device::max_compute_units>();
943+
queue.get_device().get_info<sycl::info::device::max_compute_units>();
941944
// Spec dictates max_compute_units > 0, so no need to catch div 0
942945
return max_wgs / max_compute_units;
943946
}
944947

945948
// Calculate the number of work-groups per compute unit
946-
// \param [in] kernel SYCL kernel to calculate for
949+
// \tparam KernelName SYCL kernel name to calculate for
950+
// \tparam RangeDim the dimension of the sycl::range
947951
// \param [in] queue SYCL queue used to execute kernel
948952
// \param [in] wg_range SYCL work-group range
949953
// \param [in] local_mem_size Local memory usage per work-group in bytes
950954
// \return size_t representing maximum work-groups per compute unit
951955
template <class KernelName, int RangeDim>
952-
size_t max_active_work_groups_per_cu(KernelName kernel, sycl::queue q,
953-
sycl::range<RangeDim> wg_range,
954-
size_t local_mem_size) {
955-
return max_active_work_groups_per_cu(kernel, q, syclcompat::dim3(wg_range),
956-
local_mem_size);
956+
size_t max_active_work_groups_per_cu(
957+
sycl::range<RangeDim> wg_range, size_t local_mem_size,
958+
sycl::queue queue = syclcompat::get_default_queue()) {
959+
return max_active_work_groups_per_cu<KernelName>(syclcompat::dim3(wg_range),
960+
local_mem_size, queue);
957961
}
958962

959963
/// If x <= 2, then return a pointer to the default queue;

sycl/test-e2e/syclcompat/util/max_active_work_groups_per_cu.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
/***************************************************************************
32
*
43
* Copyright (C) Codeplay Software Ltd.
@@ -66,23 +65,22 @@ void test_max_active_work_groups_per_cu(sycl::queue q,
6665
if constexpr (!KernelName<RangeDim>::has_local_mem)
6766
assert(local_mem_size == 0 && "Bad test setup");
6867

69-
auto ctx = q.get_context();
70-
auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
71-
auto kernel = bundle.template get_kernel<KernelName<RangeDim>>();
72-
73-
size_t max_per_cu = syclcompat::max_active_work_groups_per_cu(
74-
kernel, q, wg_range, local_mem_size);
75-
68+
size_t max_per_cu = syclcompat::max_active_work_groups_per_cu<KernelName<RangeDim>>(
69+
wg_range, local_mem_size, q);
70+
7671
// Check we get the same result passing equivalent dim3
7772
syclcompat::dim3 wg_dim3{wg_range};
78-
size_t max_per_cu_dim3 = syclcompat::max_active_work_groups_per_cu(
79-
kernel, q, wg_dim3, local_mem_size);
73+
size_t max_per_cu_dim3 = syclcompat::max_active_work_groups_per_cu<KernelName<RangeDim>>(
74+
wg_dim3, local_mem_size, q);
8075
assert(max_per_cu == max_per_cu_dim3);
8176

8277
// Compare w/ reference impl
8378
size_t max_compute_units =
8479
q.get_device().get_info<sycl::info::device::max_compute_units>();
8580
namespace syclex = sycl::ext::oneapi::experimental;
81+
auto ctx = q.get_context();
82+
auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
83+
auto kernel = bundle.template get_kernel<KernelName<RangeDim>>();
8684
size_t max_wgs = kernel.template ext_oneapi_get_info<
8785
syclex::info::kernel_queue_specific::max_num_work_groups>(
8886
q, sycl::range<3>{syclcompat::dim3{wg_range}}, local_mem_size);

0 commit comments

Comments
 (0)