File tree Expand file tree Collapse file tree 3 files changed +431
-114
lines changed
exp_command_buffer/update Expand file tree Collapse file tree 3 files changed +431
-114
lines changed Original file line number Diff line number Diff line change @@ -1396,14 +1396,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13961396
13971397 CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params ;
13981398
1399+ const auto LocalSize = KernelCommandHandle->Kernel ->getLocalSize ();
1400+ if (LocalSize != 0 ) {
1401+ // Clean the local size, otherwise calling updateKernelArguments() in
1402+ // future updates with local arguments will incorrectly increase the
1403+ // size further.
1404+ KernelCommandHandle->Kernel ->clearLocalSize ();
1405+ }
1406+
13991407 Params.func = CuFunc;
1400- Params.gridDimX = BlocksPerGrid[0 ];
1401- Params.gridDimY = BlocksPerGrid[1 ];
1402- Params.gridDimZ = BlocksPerGrid[2 ];
1403- Params.blockDimX = ThreadsPerBlock[0 ];
1404- Params.blockDimY = ThreadsPerBlock[1 ];
1405- Params.blockDimZ = ThreadsPerBlock[2 ];
1406- Params.sharedMemBytes = KernelCommandHandle-> Kernel -> getLocalSize () ;
1408+ Params.gridDimX = static_cast < unsigned int >( BlocksPerGrid[0 ]) ;
1409+ Params.gridDimY = static_cast < unsigned int >( BlocksPerGrid[1 ]) ;
1410+ Params.gridDimZ = static_cast < unsigned int >( BlocksPerGrid[2 ]) ;
1411+ Params.blockDimX = static_cast < unsigned int >( ThreadsPerBlock[0 ]) ;
1412+ Params.blockDimY = static_cast < unsigned int >( ThreadsPerBlock[1 ]) ;
1413+ Params.blockDimZ = static_cast < unsigned int >( ThreadsPerBlock[2 ]) ;
1414+ Params.sharedMemBytes = LocalSize ;
14071415 Params.kernelParams =
14081416 const_cast <void **>(KernelCommandHandle->Kernel ->getArgIndices ().data ());
14091417
Original file line number Diff line number Diff line change @@ -15,15 +15,27 @@ int main() {
1515 uint32_t A = 42 ;
1616
1717 sycl_queue.submit ([&](sycl::handler &cgh) {
18- sycl::local_accessor<uint32_t , 1 > local_mem (local_size, cgh);
18+ sycl::local_accessor<uint32_t , 1 > local_mem_A (local_size, cgh);
19+ sycl::local_accessor<uint32_t , 1 > local_mem_B (1 , cgh);
20+
1921 cgh.parallel_for <class saxpy_usm_local_mem >(
2022 sycl::nd_range<1 >{{array_size}, {local_size}},
2123 [=](sycl::nd_item<1 > itemId) {
2224 auto i = itemId.get_global_linear_id ();
2325 auto local_id = itemId.get_local_linear_id ();
24- local_mem[local_id] = i;
25- Z[i] = A * X[i] + Y[i] + local_mem[local_id] +
26+
27+ local_mem_A[local_id] = i;
28+ if (i == 0 ) {
29+ local_mem_B[0 ] = 0xA ;
30+ }
31+
32+ Z[i] = A * X[i] + Y[i] + local_mem_A[local_id] +
2633 itemId.get_local_range (0 );
34+
35+ if (i == 0 ) {
36+ Z[i] += local_mem_B[0 ];
37+ }
38+
2739 });
2840 });
2941 return 0 ;
You can’t perform that action at this time.
0 commit comments