1818
1919#include < cmath>
2020#include < cuda.h>
21+ #include < ur/ur.hpp>
2122
2223ur_result_t enqueueEventsWait (ur_queue_handle_t CommandQueue, CUstream Stream,
2324 uint32_t NumEventsInWaitList,
@@ -140,12 +141,10 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
140141void guessLocalWorkSize (ur_device_handle_t Device, size_t *ThreadsPerBlock,
141142 const size_t *GlobalWorkSize, const uint32_t WorkDim,
142143 const size_t MaxThreadsPerBlock[3 ],
143- ur_kernel_handle_t Kernel, uint32_t LocalSize ) {
144+ ur_kernel_handle_t Kernel) {
144145 assert (ThreadsPerBlock != nullptr );
145146 assert (GlobalWorkSize != nullptr );
146147 assert (Kernel != nullptr );
147- int MinGrid, MaxBlockSize;
148- size_t MaxBlockDim[3 ];
149148
150149 // The below assumes a three dimensional range but this is not guaranteed by
151150 // UR.
@@ -154,33 +153,18 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
154153 GlobalSizeNormalized[i] = GlobalWorkSize[i];
155154 }
156155
156+ size_t MaxBlockDim[3 ];
157+ MaxBlockDim[0 ] = MaxThreadsPerBlock[0 ];
157158 MaxBlockDim[1 ] = Device->getMaxBlockDimY ();
158159 MaxBlockDim[2 ] = Device->getMaxBlockDimZ ();
159160
160- UR_CHECK_ERROR (
161- cuOccupancyMaxPotentialBlockSize (&MinGrid, &MaxBlockSize, Kernel->get (),
162- NULL , LocalSize, MaxThreadsPerBlock[0 ]));
163-
164- ThreadsPerBlock[2 ] = std::min (GlobalSizeNormalized[2 ], MaxBlockDim[2 ]);
165- ThreadsPerBlock[1 ] =
166- std::min (GlobalSizeNormalized[1 ],
167- std::min (MaxBlockSize / ThreadsPerBlock[2 ], MaxBlockDim[1 ]));
168- MaxBlockDim[0 ] = MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[2 ]);
169- ThreadsPerBlock[0 ] = std::min (
170- MaxThreadsPerBlock[0 ], std::min (GlobalSizeNormalized[0 ], MaxBlockDim[0 ]));
171-
172- static auto IsPowerOf2 = [](size_t Value) -> bool {
173- return Value && !(Value & (Value - 1 ));
174- };
175-
176- // Find a local work group size that is a divisor of the global
177- // work group size to produce uniform work groups.
178- // Additionally, for best compute utilisation, the local size has
179- // to be a power of two.
180- while (0u != (GlobalSizeNormalized[0 ] % ThreadsPerBlock[0 ]) ||
181- !IsPowerOf2 (ThreadsPerBlock[0 ])) {
182- --ThreadsPerBlock[0 ];
183- }
161+ int MinGrid, MaxBlockSize;
162+ UR_CHECK_ERROR (cuOccupancyMaxPotentialBlockSize (
163+ &MinGrid, &MaxBlockSize, Kernel->get (), NULL , Kernel->getLocalSize (),
164+ MaxThreadsPerBlock[0 ]));
165+
166+ roundToHighestFactorOfGlobalSizeIn3d (ThreadsPerBlock, GlobalSizeNormalized,
167+ MaxBlockDim, MaxBlockSize);
184168}
185169
186170// Helper to verify out-of-registers case (exceeded block max registers).
@@ -261,7 +245,7 @@ setKernelParams(const ur_context_handle_t Context,
261245 }
262246 } else {
263247 guessLocalWorkSize (Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
264- MaxThreadsPerBlock, Kernel, LocalSize );
248+ MaxThreadsPerBlock, Kernel);
265249 }
266250 }
267251
0 commit comments