@@ -140,7 +140,6 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
140140// dimension.
141141void guessLocalWorkSize (ur_device_handle_t Device, size_t *ThreadsPerBlock,
142142 const size_t *GlobalWorkSize, const uint32_t WorkDim,
143- const size_t MaxThreadsPerBlock[3 ],
144143 ur_kernel_handle_t Kernel) {
145144 assert (ThreadsPerBlock != nullptr );
146145 assert (GlobalWorkSize != nullptr );
@@ -154,14 +153,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
154153 }
155154
156155 size_t MaxBlockDim[3 ];
157- MaxBlockDim[0 ] = MaxThreadsPerBlock[ 0 ] ;
158- MaxBlockDim[1 ] = Device->getMaxBlockDimY ( );
159- MaxBlockDim[2 ] = Device->getMaxBlockDimZ ( );
156+ MaxBlockDim[0 ] = Device-> getMaxWorkItemSizes ( 0 ) ;
157+ MaxBlockDim[1 ] = Device->getMaxWorkItemSizes ( 1 );
158+ MaxBlockDim[2 ] = Device->getMaxWorkItemSizes ( 2 );
160159
161160 int MinGrid, MaxBlockSize;
162161 UR_CHECK_ERROR (cuOccupancyMaxPotentialBlockSize (
163162 &MinGrid, &MaxBlockSize, Kernel->get (), NULL , Kernel->getLocalSize (),
164- MaxThreadsPerBlock [0 ]));
163+ MaxBlockDim [0 ]));
165164
166165 roundToHighestFactorOfGlobalSizeIn3d (ThreadsPerBlock, GlobalSizeNormalized,
167166 MaxBlockDim, MaxBlockSize);
@@ -197,7 +196,6 @@ setKernelParams(const ur_context_handle_t Context,
197196 size_t (&BlocksPerGrid)[3]) {
198197 ur_result_t Result = UR_RESULT_SUCCESS;
199198 size_t MaxWorkGroupSize = 0u ;
200- size_t MaxThreadsPerBlock[3 ] = {};
201199 bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr ;
202200 uint32_t LocalSize = Kernel->getLocalSize ();
203201
@@ -207,16 +205,14 @@ setKernelParams(const ur_context_handle_t Context,
207205 {
208206 size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock ;
209207 MaxWorkGroupSize = Device->getMaxWorkGroupSize ();
210- Device->getMaxWorkItemSizes (sizeof (MaxThreadsPerBlock),
211- MaxThreadsPerBlock);
212208
213209 if (ProvidedLocalWorkGroupSize) {
214210 auto IsValid = [&](int Dim) {
215211 if (ReqdThreadsPerBlock[Dim] != 0 &&
216212 LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
217213 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
218214
219- if (LocalWorkSize[Dim] > MaxThreadsPerBlock[ Dim] )
215+ if (LocalWorkSize[Dim] > Device-> getMaxWorkItemSizes ( Dim) )
220216 return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
221217 // Checks that local work sizes are a divisor of the global work sizes
222218 // which includes that the local work sizes are neither larger than
@@ -245,7 +241,7 @@ setKernelParams(const ur_context_handle_t Context,
245241 }
246242 } else {
247243 guessLocalWorkSize (Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
248- MaxThreadsPerBlock, Kernel);
244+ Kernel);
249245 }
250246 }
251247
0 commit comments