@@ -882,18 +882,19 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command,
882882}
883883
884884/* *
885- * Updates the arguments of CommandDesc->hNewKernel
886- * @param[in] Device The device associated with the kernel being updated.
887- * @param[in] UpdateCommandDesc The update command description that contains
888- * the new kernel and its arguments.
885+ * Updates the arguments of a kernel command.
886+ * @param[in] Command The command associated with the kernel node being updated.
887+ * @param[in] UpdateCommandDesc The update command description that contains the
888+ * new arguments.
889889 * @return UR_RESULT_SUCCESS or an error code on failure
890890 */
891891ur_result_t
892- updateKernelArguments (ur_device_handle_t Device ,
892+ updateKernelArguments (ur_exp_command_buffer_command_handle_t Command ,
893893 const ur_exp_command_buffer_update_kernel_launch_desc_t
894894 *UpdateCommandDesc) {
895895
896- ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel ;
896+ ur_kernel_handle_t Kernel = Command->Kernel ;
897+ ur_device_handle_t Device = Command->CommandBuffer ->Device ;
897898
898899 // Update pointer arguments to the kernel
899900 uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs ;
@@ -905,7 +906,7 @@ updateKernelArguments(ur_device_handle_t Device,
905906 const void *ArgValue = PointerArgDesc.pNewPointerArg ;
906907
907908 try {
908- NewKernel ->setKernelArg (ArgIndex, sizeof (ArgValue), ArgValue);
909+ Kernel ->setKernelArg (ArgIndex, sizeof (ArgValue), ArgValue);
909910 } catch (ur_result_t Err) {
910911 return Err;
911912 }
@@ -922,10 +923,10 @@ updateKernelArguments(ur_device_handle_t Device,
922923
923924 try {
924925 if (ArgValue == nullptr ) {
925- NewKernel ->setKernelArg (ArgIndex, 0 , nullptr );
926+ Kernel ->setKernelArg (ArgIndex, 0 , nullptr );
926927 } else {
927928 void *HIPPtr = std::get<BufferMem>(ArgValue->Mem ).getVoid (Device);
928- NewKernel ->setKernelArg (ArgIndex, sizeof (void *), (void *)&HIPPtr);
929+ Kernel ->setKernelArg (ArgIndex, sizeof (void *), (void *)&HIPPtr);
929930 }
930931 } catch (ur_result_t Err) {
931932 return Err;
@@ -943,7 +944,7 @@ updateKernelArguments(ur_device_handle_t Device,
943944 const void *ArgValue = ValueArgDesc.pNewValueArg ;
944945
945946 try {
946- NewKernel ->setKernelArg (ArgIndex, ArgSize, ArgValue);
947+ Kernel ->setKernelArg (ArgIndex, ArgSize, ArgValue);
947948 } catch (ur_result_t Err) {
948949 return Err;
949950 }
@@ -998,9 +999,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
998999 ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer ;
9991000
10001001 UR_CHECK_ERROR (validateCommandDesc (hCommand, pUpdateKernelLaunch));
1001- UR_CHECK_ERROR (
1002- updateKernelArguments (CommandBuffer->Device , pUpdateKernelLaunch));
10031002 UR_CHECK_ERROR (updateCommand (hCommand, pUpdateKernelLaunch));
1003+ UR_CHECK_ERROR (updateKernelArguments (hCommand, pUpdateKernelLaunch));
10041004
10051005 // If no worksize is provided make sure we pass nullptr to setKernelParams
10061006 // so it can guess the local work size.
0 commit comments