@@ -76,9 +76,11 @@ ur_exp_command_buffer_command_handle_t_::
7676 ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
7777 CUgraphNode Node, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
7878 const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
79- const size_t *LocalWorkSizePtr)
80- : CommandBuffer(CommandBuffer), Kernel(Kernel), Node(Node), Params(Params),
81- WorkDim(WorkDim), RefCountInternal(1 ), RefCountExternal(1 ) {
79+ const size_t *LocalWorkSizePtr, uint32_t NumKernelAlternatives,
80+ ur_kernel_handle_t *KernelAlternatives)
81+ : CommandBuffer(CommandBuffer), Kernel(Kernel), ValidKernelHandles(),
82+ Node(Node), Params(Params), WorkDim(WorkDim), RefCountInternal(1 ),
83+ RefCountExternal(1 ) {
8284 CommandBuffer->incrementInternalReferenceCount ();
8385
8486 const size_t CopySize = sizeof (size_t ) * WorkDim;
@@ -96,6 +98,13 @@ ur_exp_command_buffer_command_handle_t_::
9698 std::memset (GlobalWorkOffset + WorkDim, 0 , ZeroSize);
9799 std::memset (GlobalWorkSize + WorkDim, 0 , ZeroSize);
98100 }
101+
102+ /* Add the default Kernel as a valid kernel handle for this command */
103+ ValidKernelHandles.insert (Kernel);
104+ if (KernelAlternatives) {
105+ ValidKernelHandles.insert (KernelAlternatives,
106+ KernelAlternatives + NumKernelAlternatives);
107+ }
99108}
100109
101110/// Helper function for finding the Cuda Nodes associated with the
@@ -344,8 +353,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
344353 ur_exp_command_buffer_handle_t hCommandBuffer, ur_kernel_handle_t hKernel,
345354 uint32_t workDim, const size_t *pGlobalWorkOffset,
346355 const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
347- uint32_t /* numKernelAlternatives*/ ,
348- ur_kernel_handle_t * /* phKernelAlternatives*/ ,
356+ uint32_t numKernelAlternatives, ur_kernel_handle_t *phKernelAlternatives,
349357 uint32_t numSyncPointsInWaitList,
350358 const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
351359 ur_exp_command_buffer_sync_point_t *pSyncPoint,
@@ -356,6 +364,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
356364 UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
357365 UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
358366
367+ for (uint32_t i = 0 ; i < numKernelAlternatives; ++i) {
368+ UR_ASSERT (phKernelAlternatives[i] != hKernel,
369+ UR_RESULT_ERROR_INVALID_VALUE);
370+ }
371+
359372 CUgraphNode GraphNode;
360373
361374 std::vector<CUgraphNode> DepsList;
@@ -420,8 +433,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
420433 }
421434
422435 auto NewCommand = new ur_exp_command_buffer_command_handle_t_{
423- hCommandBuffer, hKernel, GraphNode, NodeParams,
424- workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize};
436+ hCommandBuffer, hKernel, GraphNode,
437+ NodeParams, workDim, pGlobalWorkOffset,
438+ pGlobalWorkSize, pLocalWorkSize, numKernelAlternatives,
439+ phKernelAlternatives};
425440
426441 NewCommand->incrementInternalReferenceCount ();
427442 hCommandBuffer->CommandHandles .push_back (NewCommand);
@@ -865,10 +880,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
865880 }
866881
867882 if (auto NewWorkDim = pUpdateKernelLaunch->newWorkDim ) {
868- // Error if work dim changes
869- if (NewWorkDim != hCommand->WorkDim ) {
870- return UR_RESULT_ERROR_INVALID_OPERATION;
871- }
872883
873884 // Error If Local size and not global size
874885 if ((pUpdateKernelLaunch->pNewLocalWorkSize != nullptr ) &&
@@ -888,7 +899,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
888899 }
889900
890901 // Kernel corresponding to the command to update
891- ur_kernel_handle_t Kernel = hCommand->Kernel ;
902+ ur_kernel_handle_t NewKernel = pUpdateKernelLaunch->hNewKernel ;
903+
904+ if (hCommand->ValidKernelHandles .count (NewKernel)) {
905+ hCommand->Kernel = NewKernel;
906+ } else {
907+ return UR_RESULT_ERROR_INVALID_VALUE;
908+ }
892909
893910 // Update pointer arguments to the kernel
894911 uint32_t NumPointerArgs = pUpdateKernelLaunch->numNewPointerArgs ;
@@ -901,7 +918,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
901918
902919 ur_result_t Result = UR_RESULT_SUCCESS;
903920 try {
904- Kernel ->setKernelArg (ArgIndex, sizeof (ArgValue), ArgValue);
921+ NewKernel ->setKernelArg (ArgIndex, sizeof (ArgValue), ArgValue);
905922 } catch (ur_result_t Err) {
906923 Result = Err;
907924 return Result;
@@ -920,11 +937,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
920937 ur_result_t Result = UR_RESULT_SUCCESS;
921938 try {
922939 if (ArgValue == nullptr ) {
923- Kernel ->setKernelArg (ArgIndex, 0 , nullptr );
940+ NewKernel ->setKernelArg (ArgIndex, 0 , nullptr );
924941 } else {
925942 CUdeviceptr CuPtr =
926943 std::get<BufferMem>(ArgValue->Mem ).getPtr (CommandBuffer->Device );
927- Kernel ->setKernelArg (ArgIndex, sizeof (CUdeviceptr), (void *)&CuPtr);
944+ NewKernel ->setKernelArg (ArgIndex, sizeof (CUdeviceptr), (void *)&CuPtr);
928945 }
929946 } catch (ur_result_t Err) {
930947 Result = Err;
@@ -945,7 +962,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
945962 ur_result_t Result = UR_RESULT_SUCCESS;
946963
947964 try {
948- Kernel ->setKernelArg (ArgIndex, ArgSize, ArgValue);
965+ NewKernel ->setKernelArg (ArgIndex, ArgSize, ArgValue);
949966 } catch (ur_result_t Err) {
950967 Result = Err;
951968 return Result;
@@ -985,12 +1002,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
9851002 // by default unless user has provided a better number
9861003 size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
9871004 size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
988- CUfunction CuFunc = Kernel ->get ();
1005+ CUfunction CuFunc = NewKernel ->get ();
9891006 ur_context_handle_t Context = CommandBuffer->Context ;
9901007 ur_device_handle_t Device = CommandBuffer->Device ;
9911008 auto Result = setKernelParams (Context, Device, WorkDim, GlobalWorkOffset,
992- GlobalWorkSize, LocalWorkSize, Kernel, CuFunc ,
993- ThreadsPerBlock, BlocksPerGrid);
1009+ GlobalWorkSize, LocalWorkSize, NewKernel ,
1010+ CuFunc, ThreadsPerBlock, BlocksPerGrid);
9941011 if (Result != UR_RESULT_SUCCESS) {
9951012 return Result;
9961013 }
@@ -1004,8 +1021,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
10041021 Params.blockDimX = ThreadsPerBlock[0 ];
10051022 Params.blockDimY = ThreadsPerBlock[1 ];
10061023 Params.blockDimZ = ThreadsPerBlock[2 ];
1007- Params.sharedMemBytes = Kernel ->getLocalSize ();
1008- Params.kernelParams = const_cast <void **>(Kernel ->getArgIndices ().data ());
1024+ Params.sharedMemBytes = NewKernel ->getLocalSize ();
1025+ Params.kernelParams = const_cast <void **>(NewKernel ->getArgIndices ().data ());
10091026
10101027 CUgraphNode Node = hCommand->Node ;
10111028 CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec ;
0 commit comments