@@ -30,12 +30,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
3030 const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
3131 const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
3232 const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
33+ std::vector<size_t > compiledLocalWorksize;
34+ if (!pLocalWorkSize) {
35+ cl_device_id device = nullptr ;
36+ CL_RETURN_ON_FAILURE (clGetCommandQueueInfo (
37+ cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_DEVICE,
38+ sizeof (device), &device, nullptr ));
39+ // This query always returns size_t[3], if nothing was specified it returns
40+ // all zeroes.
41+ size_t queriedLocalWorkSize[3 ] = {0 , 0 , 0 };
42+ CL_RETURN_ON_FAILURE (clGetKernelWorkGroupInfo (
43+ cl_adapter::cast<cl_kernel>(hKernel), device,
44+ CL_KERNEL_COMPILE_WORK_GROUP_SIZE, sizeof (size_t [3 ]),
45+ queriedLocalWorkSize, nullptr ));
46+ if (queriedLocalWorkSize[0 ] != 0 ) {
47+ for (uint32_t i = 0 ; i < workDim; i++) {
48+ compiledLocalWorksize.push_back (queriedLocalWorkSize[i]);
49+ }
50+ }
51+ }
3352
3453 CL_RETURN_ON_FAILURE (clEnqueueNDRangeKernel (
3554 cl_adapter::cast<cl_command_queue>(hQueue),
3655 cl_adapter::cast<cl_kernel>(hKernel), workDim, pGlobalWorkOffset,
37- pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList,
38- cl_adapter::cast<const cl_event *>(phEventWaitList),
56+ pGlobalWorkSize,
57+ compiledLocalWorksize.empty () ? pLocalWorkSize
58+ : compiledLocalWorksize.data (),
59+ numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
3960 cl_adapter::cast<cl_event *>(phEvent)));
4061
4162 return UR_RESULT_SUCCESS;
0 commit comments