@@ -119,18 +119,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
     GlobalSizeNormalized[i] = GlobalWorkSize[i];
   }

-  size_t MaxBlockDim[3];
-  MaxBlockDim[0] = Device->getMaxWorkItemSizes(0);
-  MaxBlockDim[1] = Device->getMaxWorkItemSizes(1);
-  MaxBlockDim[2] = Device->getMaxWorkItemSizes(2);
-
   int MinGrid, MaxBlockSize;
   UR_CHECK_ERROR(cuOccupancyMaxPotentialBlockSize(
       &MinGrid, &MaxBlockSize, Kernel->get(), NULL, Kernel->getLocalSize(),
-      MaxBlockDim[0]));
+      Device->getMaxWorkItemSizes(0)));

   roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
-                                       MaxBlockDim, MaxBlockSize);
+                                       Device->getMaxWorkItemSizes(),
+                                       MaxBlockSize);
 }

 // Helper to verify out-of-registers case (exceeded block max registers).
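For reference, the hunk above drops the locally built MaxBlockDim array and passes the device's per-dimension work-item limits straight into cuOccupancyMaxPotentialBlockSize and roundToHighestFactorOfGlobalSizeIn3d. The body of that rounding helper is not part of this diff; the following is only a hedged sketch of the kind of rounding it could perform (the name, signature, and exact strategy here are assumptions, not the adapter's verbatim code):

#include <algorithm>
#include <cstddef>

// Hypothetical sketch: per dimension, choose the largest divisor of the global
// size that fits both the per-dimension limit and the remaining block budget.
static void sketchRoundToHighestFactorIn3d(size_t (&ThreadsPerBlock)[3],
                                           const size_t GlobalSize[3],
                                           const size_t MaxBlockDim[3],
                                           int MaxBlockSize) {
  size_t Remaining = static_cast<size_t>(MaxBlockSize);
  for (int Dim = 0; Dim < 3; ++Dim) {
    size_t Candidate =
        std::min({MaxBlockDim[Dim], Remaining, GlobalSize[Dim]});
    Candidate = std::max<size_t>(Candidate, 1); // Never launch a zero-sized dim.
    // Walk down until the candidate evenly divides the global size.
    while (Candidate > 1 && GlobalSize[Dim] % Candidate != 0)
      --Candidate;
    ThreadsPerBlock[Dim] = Candidate;
    Remaining /= Candidate; // Budget left for the remaining dimensions.
  }
}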
@@ -145,7 +141,6 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,

 // Helper to compute kernel parameters from workload
 // dimensions.
-// @param [in] Context handler to the target Context
 // @param [in] Device handler to the target Device
 // @param [in] WorkDim workload dimension
 // @param [in] GlobalWorkOffset pointer workload global offsets
@@ -155,73 +150,56 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
 // @param [out] ThreadsPerBlock Number of threads per block we should run
 // @param [out] BlocksPerGrid Number of blocks per grid we should run
 ur_result_t
-setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
-                const ur_device_handle_t Device, const uint32_t WorkDim,
+setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
                 const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize,
                 const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
                 CUfunction &CuFunc, size_t (&ThreadsPerBlock)[3],
                 size_t (&BlocksPerGrid)[3]) {
-  size_t MaxWorkGroupSize = 0u;
-  bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr;
-
   try {
     // Set the active context here as guessLocalWorkSize needs an active context
     ScopedContext Active(Device);
-    {
-      size_t *MaxThreadsPerBlock = Kernel->MaxThreadsPerBlock;
-      size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
-      MaxWorkGroupSize = Device->getMaxWorkGroupSize();
-
-      if (ProvidedLocalWorkGroupSize) {
-        auto IsValid = [&](int Dim) {
-          if (ReqdThreadsPerBlock[Dim] != 0 &&
-              LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-
-          if (MaxThreadsPerBlock[Dim] != 0 &&
-              LocalWorkSize[Dim] > MaxThreadsPerBlock[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-
-          if (LocalWorkSize[Dim] > Device->getMaxWorkItemSizes(Dim))
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          // Checks that local work sizes are a divisor of the global work sizes
-          // which includes that the local work sizes are neither larger than
-          // the global work sizes and not 0.
-          if (0u == LocalWorkSize[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          if (0u != (GlobalWorkSize[Dim] % LocalWorkSize[Dim]))
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          ThreadsPerBlock[Dim] = LocalWorkSize[Dim];
-          return UR_RESULT_SUCCESS;
-        };
-
-        size_t KernelLocalWorkGroupSize = 1;
-        for (size_t Dim = 0; Dim < WorkDim; Dim++) {
-          auto Err = IsValid(Dim);
-          if (Err != UR_RESULT_SUCCESS)
-            return Err;
-          // If no error then compute the total local work size as a product of
-          // all dims.
-          KernelLocalWorkGroupSize *= LocalWorkSize[Dim];
-        }

-        if (size_t MaxLinearThreadsPerBlock = Kernel->MaxLinearThreadsPerBlock;
-            MaxLinearThreadsPerBlock &&
-            MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
+    if (LocalWorkSize != nullptr) {
+      size_t KernelLocalWorkGroupSize = 1;
+      for (size_t i = 0; i < WorkDim; i++) {
+        if (Kernel->ReqdThreadsPerBlock[i] &&
+            Kernel->ReqdThreadsPerBlock[i] != LocalWorkSize[i])
           return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-        }

-        if (hasExceededMaxRegistersPerBlock(Device, Kernel,
-                                            KernelLocalWorkGroupSize)) {
-          return UR_RESULT_ERROR_OUT_OF_RESOURCES;
-        }
-      } else {
-        guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
-                           Kernel);
+        if (Kernel->MaxThreadsPerBlock[i] &&
+            Kernel->MaxThreadsPerBlock[i] < LocalWorkSize[i])
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+
+        if (LocalWorkSize[i] > Device->getMaxWorkItemSizes(i))
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+        // Checks that local work sizes are a divisor of the global work sizes
+        // which includes that the local work sizes are neither larger than
+        // the global work sizes and not 0.
+        if (0u == LocalWorkSize[i] ||
+            0u != (GlobalWorkSize[i] % LocalWorkSize[i]))
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+
+        ThreadsPerBlock[i] = LocalWorkSize[i];
+
+        // Compute the total local work size as a product of all dimensions.
+        KernelLocalWorkGroupSize *= LocalWorkSize[i];
       }
+
+      if (Kernel->MaxLinearThreadsPerBlock &&
+          Kernel->MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
+        return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+      }
+
+      if (hasExceededMaxRegistersPerBlock(Device, Kernel,
+                                          KernelLocalWorkGroupSize)) {
+        return UR_RESULT_ERROR_OUT_OF_RESOURCES;
+      }
+    } else {
+      guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
+                         Kernel);
     }

-    if (MaxWorkGroupSize <
+    if (Device->getMaxWorkGroupSize() <
         ThreadsPerBlock[0] * ThreadsPerBlock[1] * ThreadsPerBlock[2]) {
       return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
     }
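The out-of-resources path above relies on hasExceededMaxRegistersPerBlock, whose body lies outside this diff. As a rough illustration only (the real helper takes UR handles and its exact logic may differ), a check of this kind can be expressed against the raw CUDA driver API:

#include <cuda.h>
#include <cstddef>

// Hypothetical sketch: compare (registers per thread * proposed block size)
// against the device's per-block register limit.
static bool sketchHasExceededMaxRegistersPerBlock(CUdevice Dev, CUfunction Func,
                                                  size_t BlockSize) {
  int RegsPerThread = 0;
  cuFuncGetAttribute(&RegsPerThread, CU_FUNC_ATTRIBUTE_NUM_REGS, Func);
  int MaxRegsPerBlock = 0;
  cuDeviceGetAttribute(&MaxRegsPerBlock,
                       CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, Dev);
  return BlockSize * static_cast<size_t>(RegsPerThread) >
         static_cast<size_t>(MaxRegsPerBlock);
}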
@@ -407,10 +385,9 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,

   // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
   // using the standard UR_CHECK_ERROR
-  if (ur_result_t Ret =
-          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+  if (ur_result_t Ret = setKernelParams(
+          hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+          pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;

@@ -595,10 +572,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(

   // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
   // using the standard UR_CHECK_ERROR
-  if (ur_result_t Ret =
-          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+  if (ur_result_t Ret = setKernelParams(
+          hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+          pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;
