@@ -122,6 +122,32 @@ class kernel_impl {
122122 template <typename Param>
123123 typename Param::return_type ext_oneapi_get_info (queue Queue) const ;
124124
125+ // / Query queue/launch-specific information from a kernel using the
126+ // / info::kernel_queue_specific descriptor for a specific Queue and values.
127+ // / max_num_work_groups is the only valid descriptor for this function.
128+ // /
129+ // / \param Queue is a valid SYCL queue.
130+ // / \param WorkGroupSize is the work-group size the number of work-groups is
131+ // / requested for.
132+ // / \return depends on information being queried.
133+ template <typename Param>
134+ typename Param::return_type
135+ ext_oneapi_get_info (queue Queue, const range<1 > &MaxWorkGroupSize,
136+ size_t DynamicLocalMemorySize) const ;
137+
138+ // / Query queue/launch-specific information from a kernel using the
139+ // / info::kernel_queue_specific descriptor for a specific Queue and values.
140+ // / max_num_work_groups is the only valid descriptor for this function.
141+ // /
142+ // / \param Queue is a valid SYCL queue.
143+ // / \param WorkGroupSize is the work-group size the number of work-groups is
144+ // / requested for.
145+ // / \return depends on information being queried.
146+ template <typename Param>
147+ typename Param::return_type
148+ ext_oneapi_get_info (queue Queue, const range<2 > &MaxWorkGroupSize,
149+ size_t DynamicLocalMemorySize) const ;
150+
125151 // / Query queue/launch-specific information from a kernel using the
126152 // / info::kernel_queue_specific descriptor for a specific Queue and values.
127153 // / max_num_work_groups is the only valid descriptor for this function.
@@ -192,11 +218,49 @@ class kernel_impl {
192218
193219 // / Check if the occupancy limits are exceeded for the given kernel launch
194220 // / configuration.
221+ template <int Dimensions>
195222 bool exceedsOccupancyResourceLimits (const device &Device,
196- const range<3 > &WorkGroupSize,
223+ const range<Dimensions > &WorkGroupSize,
197224 size_t DynamicLocalMemorySize) const ;
225+ template <int Dimensions>
226+ size_t queryMaxNumWorkGroups (queue Queue,
227+ const range<Dimensions> &WorkGroupSize,
228+ size_t DynamicLocalMemorySize) const ;
198229};
199230
231+ template <int Dimensions>
232+ bool kernel_impl::exceedsOccupancyResourceLimits (
233+ const device &Device, const range<Dimensions> &WorkGroupSize,
234+ size_t DynamicLocalMemorySize) const {
235+ // Respect occupancy limits for WorkGroupSize and DynamicLocalMemorySize.
236+ // Generally, exceeding hardware resource limits will yield in an error when
237+ // the kernel is launched.
238+ const size_t MaxWorkGroupSize =
239+ get_info<info::kernel_device_specific::work_group_size>(Device);
240+ const size_t MaxLocalMemorySizeInBytes =
241+ Device.get_info <info::device::local_mem_size>();
242+
243+ if (WorkGroupSize.size () > MaxWorkGroupSize)
244+ return true ;
245+
246+ if (DynamicLocalMemorySize > MaxLocalMemorySizeInBytes)
247+ return true ;
248+
249+ // It will be impossible to launch a kernel for Cuda when the hardware limit
250+ // for the 32-bit registers page file size is exceeded.
251+ if (Device.get_backend () == backend::ext_oneapi_cuda) {
252+ const uint32_t RegsPerWorkItem =
253+ get_info<info::kernel_device_specific::ext_codeplay_num_regs>(Device);
254+ const uint32_t MaxRegsPerWorkGroup =
255+ Device.get_info <ext::codeplay::experimental::info::device::
256+ max_registers_per_work_group>();
257+ if ((MaxWorkGroupSize * RegsPerWorkItem) > MaxRegsPerWorkGroup)
258+ return true ;
259+ }
260+
261+ return false ;
262+ }
263+
200264template <typename Param>
201265inline typename Param::return_type kernel_impl::get_info () const {
202266 static_assert (is_kernel_info_desc<Param>::value,
@@ -243,13 +307,11 @@ kernel_impl::get_info(const device &Device,
243307
244308namespace syclex = ext::oneapi::experimental;
245309
246- template <>
247- inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
248- return_type
249- kernel_impl::ext_oneapi_get_info<
250- syclex::info::kernel_queue_specific::max_num_work_groups>(
251- queue Queue, const range<3 > &WorkGroupSize,
252- size_t DynamicLocalMemorySize) const {
310+ template <int Dimensions>
311+ size_t
312+ kernel_impl::queryMaxNumWorkGroups (queue Queue,
313+ const range<Dimensions> &WorkGroupSize,
314+ size_t DynamicLocalMemorySize) const {
253315 if (WorkGroupSize.size () == 0 )
254316 throw exception (sycl::make_error_code (errc::invalid),
255317 " The launch work-group size cannot be zero." );
@@ -258,10 +320,17 @@ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
258320 const auto &Handle = getHandleRef ();
259321 auto Device = Queue.get_device ();
260322
323+ size_t WG[Dimensions];
324+ WG[0 ] = WorkGroupSize[0 ];
325+ if constexpr (Dimensions >= 2 )
326+ WG[1 ] = WorkGroupSize[1 ];
327+ if constexpr (Dimensions == 3 )
328+ WG[2 ] = WorkGroupSize[2 ];
329+
261330 uint32_t GroupCount{0 };
262331 if (auto Result = Adapter->call_nocheck <
263332 UrApiKind::urKernelSuggestMaxCooperativeGroupCountExp>(
264- Handle, WorkGroupSize. size () , DynamicLocalMemorySize, &GroupCount);
333+ Handle, Dimensions, WG , DynamicLocalMemorySize, &GroupCount);
265334 Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
266335 // The feature is supported. Check for other errors and throw if any.
267336 Adapter->checkUrResult (Result);
@@ -277,15 +346,33 @@ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
277346}
278347
279348template <>
280- inline typename syclex::info::kernel_queue_specific::max_num_work_group_sync ::
349+ inline typename syclex::info::kernel_queue_specific::max_num_work_groups ::
281350 return_type
282351 kernel_impl::ext_oneapi_get_info<
283- syclex::info::kernel_queue_specific::max_num_work_group_sync>(
352+ syclex::info::kernel_queue_specific::max_num_work_groups>(
353+ queue Queue, const range<1 > &WorkGroupSize,
354+ size_t DynamicLocalMemorySize) const {
355+ return queryMaxNumWorkGroups (Queue, WorkGroupSize, DynamicLocalMemorySize);
356+ }
357+
358+ template <>
359+ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
360+ return_type
361+ kernel_impl::ext_oneapi_get_info<
362+ syclex::info::kernel_queue_specific::max_num_work_groups>(
363+ queue Queue, const range<2 > &WorkGroupSize,
364+ size_t DynamicLocalMemorySize) const {
365+ return queryMaxNumWorkGroups (Queue, WorkGroupSize, DynamicLocalMemorySize);
366+ }
367+
368+ template <>
369+ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
370+ return_type
371+ kernel_impl::ext_oneapi_get_info<
372+ syclex::info::kernel_queue_specific::max_num_work_groups>(
284373 queue Queue, const range<3 > &WorkGroupSize,
285374 size_t DynamicLocalMemorySize) const {
286- return ext_oneapi_get_info<
287- syclex::info::kernel_queue_specific::max_num_work_groups>(
288- Queue, WorkGroupSize, DynamicLocalMemorySize);
375+ return queryMaxNumWorkGroups (Queue, WorkGroupSize, DynamicLocalMemorySize);
289376}
290377
291378template <>
@@ -299,7 +386,7 @@ inline typename syclex::info::kernel_queue_specific::max_num_work_group_sync::
299386 get_info<info::kernel_device_specific::work_group_size>(Device);
300387 const sycl::range<3 > WorkGroupSize{MaxWorkGroupSize, 1 , 1 };
301388 return ext_oneapi_get_info<
302- syclex::info::kernel_queue_specific::max_num_work_group_sync >(
389+ syclex::info::kernel_queue_specific::max_num_work_groups >(
303390 Queue, WorkGroupSize, /* DynamicLocalMemorySize */ 0 );
304391}
305392
0 commit comments