Skip to content

Commit 9eeda47

Browse files
committed
Add parameters to cooperative kernel query
Signed-off-by: Michael Aziz <[email protected]>
1 parent 48a9ef1 commit 9eeda47

File tree

14 files changed

+96
-20
lines changed

14 files changed

+96
-20
lines changed

include/ur_api.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8419,8 +8419,12 @@ urEnqueueCooperativeKernelLaunchExp(
84198419
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
84208420
UR_APIEXPORT ur_result_t UR_APICALL
84218421
urKernelSuggestMaxCooperativeGroupCountExp(
8422-
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8423-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
8422+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8423+
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8424+
///< kernel is launched
8425+
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
8426+
///< that will be used when the kernel is launched
8427+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
84248428
);
84258429

84268430
#if !defined(__GNUC__)
@@ -9368,6 +9372,8 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
93689372
/// allowing the callback the ability to modify the parameter's value
93699373
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
93709374
ur_kernel_handle_t *phKernel;
9375+
size_t *plocalWorkSize;
9376+
size_t *pdynamicSharedMemorySize;
93719377
uint32_t **ppGroupCountRet;
93729378
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;
93739379

include/ur_ddi.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
627627
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
628628
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
629629
ur_kernel_handle_t,
630+
size_t,
631+
size_t,
630632
uint32_t *);
631633

632634
///////////////////////////////////////////////////////////////////////////////

include/ur_print.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10868,6 +10868,16 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1086810868
ur::details::printPtr(os,
1086910869
*(params->phKernel));
1087010870

10871+
os << ", ";
10872+
os << ".localWorkSize = ";
10873+
10874+
os << *(params->plocalWorkSize);
10875+
10876+
os << ", ";
10877+
os << ".dynamicSharedMemorySize = ";
10878+
10879+
os << *(params->pdynamicSharedMemorySize);
10880+
1087110881
os << ", ";
1087210882
os << ".pGroupCountRet = ";
1087310883

scripts/core/exp-cooperative-kernels.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81+
- type: size_t
82+
name: localWorkSize
83+
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"
84+
- type: size_t
85+
name: dynamicSharedMemorySize
86+
desc: "[in] size of dynamic shared memory, for each work-group, in bytes, that will be used when the kernel is launched"
8187
- type: "uint32_t*"
8288
name: "pGroupCountRet"
8389
desc: "[out] pointer to maximum number of groups"

source/adapters/cuda/kernel.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
170170
}
171171

172172
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
173-
ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
173+
ur_kernel_handle_t hKernel, size_t localWorkSize,
174+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
174175
(void)hKernel;
176+
(void)localWorkSize;
177+
(void)dynamicSharedMemorySize;
175178
*pGroupCountRet = 1;
176179
return UR_RESULT_SUCCESS;
177180
}

source/adapters/hip/kernel.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,11 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
159159
}
160160

161161
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
162-
ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
162+
ur_kernel_handle_t hKernel, size_t localWorkSize,
163+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
163164
(void)hKernel;
165+
(void)localWorkSize;
166+
(void)dynamicSharedMemorySize;
164167
*pGroupCountRet = 1;
165168
return UR_RESULT_SUCCESS;
166169
}

source/adapters/level_zero/kernel.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
747747
}
748748

749749
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
750-
ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
750+
ur_kernel_handle_t hKernel, size_t localWorkSize,
751+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
751752
(void)hKernel;
753+
(void)localWorkSize;
754+
(void)dynamicSharedMemorySize;
752755
*pGroupCountRet = 1;
753756
return UR_RESULT_SUCCESS;
754757
}

source/adapters/null/ur_nullddi.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5073,15 +5073,22 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
50735073
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
50745074
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
50755075
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
5076-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
5076+
size_t
5077+
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
5078+
///< kernel is launched
5079+
size_t
5080+
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
5081+
///< that will be used when the kernel is launched
5082+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
50775083
) try {
50785084
ur_result_t result = UR_RESULT_SUCCESS;
50795085

50805086
// if the driver has created a custom function, then call it instead of using the generic path
50815087
auto pfnSuggestMaxCooperativeGroupCountExp =
50825088
d_context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
50835089
if (nullptr != pfnSuggestMaxCooperativeGroupCountExp) {
5084-
result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
5090+
result = pfnSuggestMaxCooperativeGroupCountExp(
5091+
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
50855092
} else {
50865093
// generic implementation
50875094
}

source/adapters/opencl/kernel.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "common.hpp"
1111

1212
#include <algorithm>
13+
#include <cstddef>
1314
#include <memory>
1415

1516
UR_APIEXPORT ur_result_t UR_APICALL
@@ -377,8 +378,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
377378
}
378379

379380
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
380-
ur_kernel_handle_t hKernel, uint32_t *pGroupCountRet) {
381+
ur_kernel_handle_t hKernel, size_t localWorkSize,
382+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
381383
(void)hKernel;
384+
(void)localWorkSize;
385+
(void)dynamicSharedMemorySize;
382386
*pGroupCountRet = 1;
383387
return UR_RESULT_SUCCESS;
384388
}

source/loader/layers/tracing/ur_trcddi.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5874,7 +5874,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
58745874
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
58755875
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
58765876
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
5877-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
5877+
size_t
5878+
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
5879+
///< kernel is launched
5880+
size_t
5881+
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
5882+
///< that will be used when the kernel is launched
5883+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
58785884
) {
58795885
auto pfnSuggestMaxCooperativeGroupCountExp =
58805886
context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
@@ -5884,13 +5890,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
58845890
}
58855891

58865892
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
5887-
&hKernel, &pGroupCountRet};
5893+
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
58885894
uint64_t instance = context.notify_begin(
58895895
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,
58905896
"urKernelSuggestMaxCooperativeGroupCountExp", &params);
58915897

5892-
ur_result_t result =
5893-
pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
5898+
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
5899+
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
58945900

58955901
context.notify_end(
58965902
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,

0 commit comments

Comments
 (0)