Skip to content

Commit 43aa821

Browse files
committed
[Spec] fix urKernelSuggestMaxCooperativeGroupCountExp
Add extra param: ur_device_handle_t. This parameter is necessary to implement this function on L0 for kernels that are built for multiple devices. Right now, the implementation only works when the kernel is created from a native handle. Ref: #2262
1 parent 3d58884 commit 43aa821

File tree

16 files changed

+67
-17
lines changed

16 files changed

+67
-17
lines changed

include/ur_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9434,12 +9434,14 @@ urEnqueueCooperativeKernelLaunchExp(
94349434
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
94359435
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
94369436
/// + `NULL == hKernel`
9437+
/// + `NULL == hDevice`
94379438
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
94389439
/// + `NULL == pGroupCountRet`
94399440
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
94409441
UR_APIEXPORT ur_result_t UR_APICALL
94419442
urKernelSuggestMaxCooperativeGroupCountExp(
94429443
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
9444+
ur_device_handle_t hDevice, ///< [in] handle of the device object
94439445
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
94449446
///< kernel is launched
94459447
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
@@ -10687,6 +10689,7 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
1068710689
/// allowing the callback the ability to modify the parameter's value
1068810690
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
1068910691
ur_kernel_handle_t *phKernel;
10692+
ur_device_handle_t *phDevice;
1069010693
size_t *plocalWorkSize;
1069110694
size_t *pdynamicSharedMemorySize;
1069210695
uint32_t **ppGroupCountRet;

include/ur_ddi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
651651
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
652652
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
653653
ur_kernel_handle_t,
654+
ur_device_handle_t,
654655
size_t,
655656
size_t,
656657
uint32_t *);

include/ur_print.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12214,6 +12214,12 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1221412214
ur::details::printPtr(os,
1221512215
*(params->phKernel));
1221612216

12217+
os << ", ";
12218+
os << ".hDevice = ";
12219+
12220+
ur::details::printPtr(os,
12221+
*(params->phDevice));
12222+
1221712223
os << ", ";
1221812224
os << ".localWorkSize = ";
1221912225

scripts/core/exp-cooperative-kernels.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81+
- type: $x_device_handle_t
82+
name: hDevice
83+
desc: "[in] handle of the device object"
8184
- type: size_t
8285
name: localWorkSize
8386
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"

source/adapters/cuda/kernel.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
190190
}
191191

192192
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
193-
ur_kernel_handle_t hKernel, size_t localWorkSize,
194-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
193+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
194+
size_t localWorkSize, size_t dynamicSharedMemorySize,
195+
uint32_t *pGroupCountRet) {
195196
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
196197

198+
std::ignore = hDevice;
199+
197200
// We need to set the active current device for this kernel explicitly here,
198201
// because the occupancy querying API does not take device parameter.
199202
ur_device_handle_t Device = hKernel->getProgram()->getDevice();

source/adapters/hip/kernel.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,11 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
169169
}
170170

171171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
172-
ur_kernel_handle_t hKernel, size_t localWorkSize,
173-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
172+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
173+
size_t localWorkSize, size_t dynamicSharedMemorySize,
174+
uint32_t *pGroupCountRet) {
174175
std::ignore = hKernel;
176+
std::ignore = hDevice;
175177
std::ignore = localWorkSize;
176178
std::ignore = dynamicSharedMemorySize;
177179
std::ignore = pGroupCountRet;

source/adapters/level_zero/kernel.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,14 +1051,17 @@ ur_result_t urKernelGetNativeHandle(
10511051
}
10521052

10531053
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
1054-
ur_kernel_handle_t hKernel, size_t localWorkSize,
1055-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
1054+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
1055+
size_t localWorkSize, size_t dynamicSharedMemorySize,
1056+
uint32_t *pGroupCountRet) {
10561057
(void)localWorkSize;
10571058
(void)dynamicSharedMemorySize;
10581059
std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
10591060
uint32_t TotalGroupCount = 0;
1061+
ze_kernel_handle_t ZeKernel;
1062+
UR_CALL(getZeKernel(hDevice->ZeDevice, hKernel, &ZeKernel));
10601063
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
1061-
(hKernel->ZeKernel, &TotalGroupCount));
1064+
(ZeKernel, &TotalGroupCount));
10621065
*pGroupCountRet = TotalGroupCount;
10631066
return UR_RESULT_SUCCESS;
10641067
}

source/adapters/level_zero/ur_interface_loader.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,9 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp(
687687
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
688688
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);
689689
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
690-
ur_kernel_handle_t hKernel, size_t localWorkSize,
691-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet);
690+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
691+
size_t localWorkSize, size_t dynamicSharedMemorySize,
692+
uint32_t *pGroupCountRet);
692693
ur_result_t urEnqueueTimestampRecordingExp(
693694
ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList,
694695
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);

source/adapters/level_zero/v2/api.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,8 +568,9 @@ ur_result_t urCommandBufferCommandGetInfoExp(
568568
}
569569

570570
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
571-
ur_kernel_handle_t hKernel, size_t localWorkSize,
572-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
571+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
572+
size_t localWorkSize, size_t dynamicSharedMemorySize,
573+
uint32_t *pGroupCountRet) {
573574
logger::error("{} function not implemented!", __FUNCTION__);
574575
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
575576
}

source/adapters/mock/ur_mockddi.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10003,6 +10003,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1000310003
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
1000410004
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1000510005
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
10006+
ur_device_handle_t hDevice, ///< [in] handle of the device object
1000610007
size_t
1000710008
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
1000810009
///< kernel is launched
@@ -10014,7 +10015,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1001410015
ur_result_t result = UR_RESULT_SUCCESS;
1001510016

1001610017
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
10017-
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
10018+
&hKernel, &hDevice, &localWorkSize, &dynamicSharedMemorySize,
10019+
&pGroupCountRet};
1001810020

1001910021
auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
1002010022
mock::getCallbacks().get_before_callback(

0 commit comments

Comments
 (0)