@@ -123,6 +123,32 @@ class kernel_impl {
123123 template <typename Param>
124124 typename Param::return_type ext_oneapi_get_info (queue Queue) const ;
125125
126+ // / Query queue/launch-specific information from a kernel using the
127+ // / info::kernel_queue_specific descriptor for a specific Queue and values.
128+ // / max_num_work_groups is the only valid descriptor for this function.
129+ // /
130+ // / \param Queue is a valid SYCL queue.
131+ // / \param WorkGroupSize is the work-group size the number of work-groups is
132+ // / requested for.
133+ // / \return depends on information being queried.
134+ template <typename Param>
135+ typename Param::return_type
136+ ext_oneapi_get_info (queue Queue, const range<1 > &MaxWorkGroupSize,
137+ size_t DynamicLocalMemorySize) const ;
138+
139+ // / Query queue/launch-specific information from a kernel using the
140+ // / info::kernel_queue_specific descriptor for a specific Queue and values.
141+ // / max_num_work_groups is the only valid descriptor for this function.
142+ // /
143+ // / \param Queue is a valid SYCL queue.
144+ // / \param WorkGroupSize is the work-group size the number of work-groups is
145+ // / requested for.
146+ // / \return depends on information being queried.
147+ template <typename Param>
148+ typename Param::return_type
149+ ext_oneapi_get_info (queue Queue, const range<2 > &MaxWorkGroupSize,
150+ size_t DynamicLocalMemorySize) const ;
151+
126152 // / Query queue/launch-specific information from a kernel using the
127153 // / info::kernel_queue_specific descriptor for a specific Queue and values.
128154 // / max_num_work_groups is the only valid descriptor for this function.
@@ -193,11 +219,49 @@ class kernel_impl {
193219
194220 /// Check if the occupancy limits are exceeded for the given kernel launch
195221 /// configuration.
///
/// \param Device is the device whose limits the configuration is checked
///        against.
/// \param WorkGroupSize is the work-group size of the launch configuration.
/// \param DynamicLocalMemorySize is the dynamic local memory size, in bytes,
///        of the launch configuration.
/// \return true if a checked limit is exceeded, false otherwise.
//
// The '-' line is the previous range<3>-only parameter; the '+' lines turn the
// declaration into a template over the number of dimensions.
222+ template <int Dimensions>
196223 bool exceedsOccupancyResourceLimits (const device &Device,
197- const range<3 > &WorkGroupSize,
224+ const range<Dimensions > &WorkGroupSize,
198225 size_t DynamicLocalMemorySize) const ;
/// Dimension-generic helper behind the max_num_work_groups query: the
/// ext_oneapi_get_info specializations for range<1>, range<2> and range<3>
/// forward their arguments to it and return its result.
///
/// \param Queue is the queue the query is made for; its device is used.
/// \param WorkGroupSize is the work-group size the number of work-groups is
///        requested for. A size of zero is rejected with an errc::invalid
///        exception.
/// \param DynamicLocalMemorySize is the dynamic local memory size passed on to
///        the underlying query.
/// \return the number of work-groups computed for the given configuration.
226+ template <int Dimensions>
227+ size_t queryMaxNumWorkGroups (queue Queue,
228+ const range<Dimensions> &WorkGroupSize,
229+ size_t DynamicLocalMemorySize) const ;
199230};
200231
232+ template <int Dimensions>
233+ bool kernel_impl::exceedsOccupancyResourceLimits (
234+ const device &Device, const range<Dimensions> &WorkGroupSize,
235+ size_t DynamicLocalMemorySize) const {
236+ // Respect occupancy limits for WorkGroupSize and DynamicLocalMemorySize.
237+ // Generally, exceeding hardware resource limits will yield in an error when
238+ // the kernel is launched.
239+ const size_t MaxWorkGroupSize =
240+ get_info<info::kernel_device_specific::work_group_size>(Device);
241+ const size_t MaxLocalMemorySizeInBytes =
242+ Device.get_info <info::device::local_mem_size>();
243+
244+ if (WorkGroupSize.size () > MaxWorkGroupSize)
245+ return true ;
246+
247+ if (DynamicLocalMemorySize > MaxLocalMemorySizeInBytes)
248+ return true ;
249+
250+ // It will be impossible to launch a kernel for Cuda when the hardware limit
251+ // for the 32-bit registers page file size is exceeded.
252+ if (Device.get_backend () == backend::ext_oneapi_cuda) {
253+ const uint32_t RegsPerWorkItem =
254+ get_info<info::kernel_device_specific::ext_codeplay_num_regs>(Device);
255+ const uint32_t MaxRegsPerWorkGroup =
256+ Device.get_info <ext::codeplay::experimental::info::device::
257+ max_registers_per_work_group>();
258+ if ((MaxWorkGroupSize * RegsPerWorkItem) > MaxRegsPerWorkGroup)
259+ return true ;
260+ }
261+
262+ return false ;
263+ }
264+
201265template <typename Param>
202266inline typename Param::return_type kernel_impl::get_info () const {
203267 static_assert (is_kernel_info_desc<Param>::value,
@@ -244,13 +308,11 @@ kernel_impl::get_info(const device &Device,
244308
245309namespace syclex = ext::oneapi::experimental;
246310
247- template <>
248- inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
249- return_type
250- kernel_impl::ext_oneapi_get_info<
251- syclex::info::kernel_queue_specific::max_num_work_groups>(
252- queue Queue, const range<3 > &WorkGroupSize,
253- size_t DynamicLocalMemorySize) const {
311+ template <int Dimensions>
312+ size_t
313+ kernel_impl::queryMaxNumWorkGroups (queue Queue,
314+ const range<Dimensions> &WorkGroupSize,
315+ size_t DynamicLocalMemorySize) const {
254316 if (WorkGroupSize.size () == 0 )
255317 throw exception (sycl::make_error_code (errc::invalid),
256318 " The launch work-group size cannot be zero." );
@@ -259,12 +321,21 @@ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
259321 const auto &Handle = getHandleRef ();
260322 auto Device = Queue.get_device ();
261323
324+ size_t WG[Dimensions];
325+ WG[0 ] = WorkGroupSize[0 ];
326+ if constexpr (Dimensions >= 2 )
327+ WG[1 ] = WorkGroupSize[1 ];
328+ if constexpr (Dimensions == 3 )
329+ WG[2 ] = WorkGroupSize[2 ];
330+
262331 uint32_t GroupCount{0 };
263332 if (auto Result = Adapter->call_nocheck <
264333 UrApiKind::urKernelSuggestMaxCooperativeGroupCountExp>(
265- Handle, WorkGroupSize.size (), DynamicLocalMemorySize, &GroupCount);
266- Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
267- // The feature is supported. Check for other errors and throw if any.
334+ Handle, Dimensions, WG, DynamicLocalMemorySize, &GroupCount);
335+ Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE &&
336+ Result != UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE) {
337+ // The feature is supported and the group size is valid. Check for other
338+ // errors and throw if any.
268339 Adapter->checkUrResult (Result);
269340 return GroupCount;
270341 }
@@ -278,30 +349,33 @@ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
278349}
279350
/// Specialization of ext_oneapi_get_info for the max_num_work_groups
/// descriptor taking a one-dimensional work-group size (the '+' lines).
/// It forwards Queue, WorkGroupSize and DynamicLocalMemorySize to
/// queryMaxNumWorkGroups and returns its result.
//
// The '-' lines are the definition this one replaces: the
// max_num_work_group_sync specialization for range<3>, which forwarded to
// ext_oneapi_get_info<max_num_work_groups>.
280351template <>
281- inline typename syclex::info::kernel_queue_specific::max_num_work_group_sync ::
352+ inline typename syclex::info::kernel_queue_specific::max_num_work_groups ::
282353 return_type
283354 kernel_impl::ext_oneapi_get_info<
284- syclex::info::kernel_queue_specific::max_num_work_group_sync >(
285- queue Queue, const range<3 > &WorkGroupSize,
355+ syclex::info::kernel_queue_specific::max_num_work_groups >(
356+ queue Queue, const range<1 > &WorkGroupSize,
286357 size_t DynamicLocalMemorySize) const {
287- return ext_oneapi_get_info<
288- syclex::info::kernel_queue_specific::max_num_work_groups>(
289- Queue, WorkGroupSize, DynamicLocalMemorySize);
358+ return queryMaxNumWorkGroups (Queue, WorkGroupSize, DynamicLocalMemorySize);
290359}
291360
/// Specialization of ext_oneapi_get_info for the max_num_work_groups
/// descriptor taking a two-dimensional work-group size (the '+' lines).
/// It forwards Queue, WorkGroupSize and DynamicLocalMemorySize to
/// queryMaxNumWorkGroups and returns its result.
//
// The '-' lines are the definition this one replaces: the queue-only
// max_num_work_group_sync specialization, which built a
// {MaxWorkGroupSize, 1, 1} range from the kernel's work_group_size on the
// queue's device and queried with a dynamic local memory size of 0.
292361template <>
293- inline typename syclex::info::kernel_queue_specific::max_num_work_group_sync ::
362+ inline typename syclex::info::kernel_queue_specific::max_num_work_groups ::
294363 return_type
295364 kernel_impl::ext_oneapi_get_info<
296- syclex::info::kernel_queue_specific::max_num_work_group_sync>(
297- queue Queue) const {
298- auto Device = Queue.get_device ();
299- const auto MaxWorkGroupSize =
300- get_info<info::kernel_device_specific::work_group_size>(Device);
301- const sycl::range<3 > WorkGroupSize{MaxWorkGroupSize, 1 , 1 };
302- return ext_oneapi_get_info<
303- syclex::info::kernel_queue_specific::max_num_work_group_sync>(
304- Queue, WorkGroupSize, /* DynamicLocalMemorySize */ 0 );
365+ syclex::info::kernel_queue_specific::max_num_work_groups>(
366+ queue Queue, const range<2 > &WorkGroupSize,
367+ size_t DynamicLocalMemorySize) const {
368+ return queryMaxNumWorkGroups (Queue, WorkGroupSize, DynamicLocalMemorySize);
369+ }
370+
371+ template <>
372+ inline typename syclex::info::kernel_queue_specific::max_num_work_groups::
373+ return_type
374+ kernel_impl::ext_oneapi_get_info<
375+ syclex::info::kernel_queue_specific::max_num_work_groups>(
376+ queue Queue, const range<3 > &WorkGroupSize,
377+ size_t DynamicLocalMemorySize) const {
378+ return queryMaxNumWorkGroups (Queue, WorkGroupSize, DynamicLocalMemorySize);
305379}
306380
307381} // namespace detail
0 commit comments