@@ -921,39 +921,43 @@ class group : public group_base<dimensions> {
921921} // namespace experimental
922922
923923// Calculate the number of work-groups per compute unit
924- // \param [in] kernel SYCL kernel to calculate for
924+ // \tparam [in] KernelName SYCL kernel name to calculate for
925925// \param [in] q SYCL queue used to execute kernel
926926// \param [in] wg_dim3 dim3 representing work-group shape
927927// \param [in] local_mem_size Local memory usage per work-group in bytes
928928// \return size_t representing maximum work-groups per compute unit
929929template <class KernelName >
930- size_t max_active_work_groups_per_cu (KernelName kernel, sycl::queue q,
931- syclcompat::dim3 wg_dim3,
932- size_t local_mem_size ) {
930+ size_t max_active_work_groups_per_cu (
931+ syclcompat::dim3 wg_dim3, size_t local_mem_size ,
932+ sycl::queue queue = syclcompat::get_default_queue() ) {
933933 namespace syclex = sycl::ext::oneapi::experimental;
934934 // max_num_work_groups only supports range<3>
935+ auto ctx = queue.get_context ();
936+ auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
937+ auto kernel = bundle.template get_kernel <KernelName>();
935938 sycl::range<3 > wg_range_3d (wg_dim3);
936939 size_t max_wgs = kernel.template ext_oneapi_get_info <
937- syclex::info::kernel_queue_specific::max_num_work_groups>(q , wg_range_3d,
940+ syclex::info::kernel_queue_specific::max_num_work_groups>(queue , wg_range_3d,
938941 local_mem_size);
939942 size_t max_compute_units =
940- q .get_device ().get_info <sycl::info::device::max_compute_units>();
943+ queue .get_device ().get_info <sycl::info::device::max_compute_units>();
941944 // Spec dictates max_compute_units > 0, so no need to catch div 0
942945 return max_wgs / max_compute_units;
943946}
944947
945948// Calculate the number of work-groups per compute unit
946- // \param [in] kernel SYCL kernel to calculate for
949+ // \tparam [in] KernelName SYCL kernel name to calculate for
950+ // \tparam [in] RangeDim the dimension of the sycl::range
947951// \param [in] q SYCL queue used to execute kernel
948952// \param [in] wg_range SYCL work-group range
949953// \param [in] local_mem_size Local memory usage per work-group in bytes
950954// \return size_t representing maximum work-groups per compute unit
951955template <class KernelName , int RangeDim>
952- size_t max_active_work_groups_per_cu (KernelName kernel, sycl::queue q,
953- sycl::range<RangeDim> wg_range,
954- size_t local_mem_size ) {
955- return max_active_work_groups_per_cu (kernel, q, syclcompat::dim3 (wg_range),
956- local_mem_size);
956+ size_t max_active_work_groups_per_cu (
957+ sycl::range<RangeDim> wg_range, size_t local_mem_size ,
958+ sycl::queue queue = syclcompat::get_default_queue() ) {
959+ return max_active_work_groups_per_cu<KernelName>( syclcompat::dim3 (wg_range),
960+ local_mem_size, queue );
957961}
958962
959963// / If x <= 2, then return a pointer to the default queue;
0 commit comments