Skip to content

Commit 48a9ef1

Browse files
committed
[UR] Add default implementation for cooperative kernel functions
Cooperative kernels can synchronize using device-scope barriers. These kernels use `urKernelSuggestMaxCooperativeGroupCountExp` to ensure that all work groups can run concurrently. When the maximum number of work groups is 1, these kernels behave the same as regular kernels.

This PR adds a default implementation for `urKernelSuggestMaxCooperativeGroupCountExp` that returns 1. It also adds a default implementation for `urEnqueueCooperativeKernelLaunchExp` that calls `urEnqueueKernelLaunch`.

Signed-off-by: Michael Aziz <[email protected]>
1 parent e1414e1 commit 48a9ef1

File tree

7 files changed

+68
-0
lines changed

7 files changed

+68
-0
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
490490
return Result;
491491
}
492492

493+
/// Default CUDA implementation of the experimental cooperative kernel launch
/// entry point.
///
/// The launch is delegated, argument for argument, to urEnqueueKernelLaunch
/// and the result of that call is handed back to the caller unchanged.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
  const ur_result_t LaunchResult = urEnqueueKernelLaunch(
      hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
      pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent);
  return LaunchResult;
}
502+
493503
/// Set parameters for general 3D memory copy.
494504
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
495505
/// must be a pointer to a CUdeviceptr

source/adapters/cuda/kernel.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
169169
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
170170
}
171171

172+
/// Default CUDA implementation: reports a maximum cooperative group count of
/// 1, regardless of which kernel is supplied.
///
/// \param hKernel Kernel handle; not inspected by this implementation.
/// \param pGroupCountRet Output location; written through unconditionally
///        and always receives 1.
/// \return UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
  static_cast<void>(hKernel); // Deliberately unused.

  const uint32_t SingleGroup = 1;
  *pGroupCountRet = SingleGroup;
  return UR_RESULT_SUCCESS;
}
178+
172179
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
173180
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
174181
const ur_kernel_arg_value_properties_t *pProperties,

source/adapters/hip/enqueue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
475475
return Result;
476476
}
477477

478+
/// Default HIP implementation of the experimental cooperative kernel launch
/// entry point.
///
/// Every argument is passed through as-is to urEnqueueKernelLaunch; whatever
/// that call returns becomes this function's return value.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
  const ur_result_t LaunchResult = urEnqueueKernelLaunch(
      hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
      pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent);
  return LaunchResult;
}
487+
478488
/// Enqueues a wait on the given queue for all events.
479489
/// See \ref enqueueEventWait
480490
///

source/adapters/hip/kernel.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,13 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
158158
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
159159
}
160160

161+
/// Default HIP implementation: always suggests a maximum cooperative group
/// count of 1.
///
/// \param hKernel Kernel handle; ignored here.
/// \param pGroupCountRet Output location; written through unconditionally
///        and set to 1.
/// \return UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
  static_cast<void>(hKernel); // Deliberately unused.

  const uint32_t SingleGroup = 1;
  *pGroupCountRet = SingleGroup;
  return UR_RESULT_SUCCESS;
}
167+
161168
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
162169
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
163170
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {

source/adapters/level_zero/kernel.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
253253
return UR_RESULT_SUCCESS;
254254
}
255255

256+
/// Default Level Zero implementation of the experimental cooperative kernel
/// launch entry point.
///
/// Acts as a pass-through: the arguments are forwarded unmodified to
/// urEnqueueKernelLaunch and its result is returned directly.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
  const ur_result_t LaunchResult = urEnqueueKernelLaunch(
      hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
      pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent);
  return LaunchResult;
}
265+
256266
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
257267
ur_queue_handle_t Queue, ///< [in] handle of the queue to submit to.
258268
ur_program_handle_t Program, ///< [in] handle of the program containing the
@@ -736,6 +746,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
736746
return UR_RESULT_SUCCESS;
737747
}
738748

749+
/// Default Level Zero implementation: the suggested maximum cooperative
/// group count is fixed at 1 and does not depend on the kernel.
///
/// \param hKernel Kernel handle; not read by this implementation.
/// \param pGroupCountRet Output location; written through unconditionally
///        and set to 1.
/// \return UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
  static_cast<void>(hKernel); // Deliberately unused.

  const uint32_t SingleGroup = 1;
  *pGroupCountRet = SingleGroup;
  return UR_RESULT_SUCCESS;
}
755+
739756
UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
740757
ur_native_handle_t NativeKernel, ///< [in] the native handle of the kernel.
741758
ur_context_handle_t Context, ///< [in] handle of the context object

source/adapters/opencl/enqueue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141
return UR_RESULT_SUCCESS;
4242
}
4343

44+
/// Default OpenCL implementation of the experimental cooperative kernel
/// launch entry point.
///
/// Hands all of its arguments to urEnqueueKernelLaunch in the same order and
/// returns that call's result without modification.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
  const ur_result_t LaunchResult = urEnqueueKernelLaunch(
      hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
      pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent);
  return LaunchResult;
}
53+
4454
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
4555
ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
4656
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

source/adapters/opencl/kernel.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
376376
return UR_RESULT_SUCCESS;
377377
}
378378

379+
/// Default OpenCL implementation: unconditionally suggests a maximum
/// cooperative group count of 1.
///
/// \param hKernel Kernel handle; unused by this implementation.
/// \param pGroupCountRet Output location; written through unconditionally
///        and set to 1.
/// \return UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
  static_cast<void>(hKernel); // Deliberately unused.

  const uint32_t SingleGroup = 1;
  *pGroupCountRet = SingleGroup;
  return UR_RESULT_SUCCESS;
}
385+
379386
UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
380387
ur_native_handle_t hNativeKernel, ur_context_handle_t, ur_program_handle_t,
381388
const ur_kernel_native_properties_t *pProperties,

0 commit comments

Comments
 (0)