diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp index e47bcf9c2a..2029903c92 100644 --- a/source/adapters/cuda/command_buffer.cpp +++ b/source/adapters/cuda/command_buffer.cpp @@ -1237,18 +1237,19 @@ validateCommandDesc(kernel_command_handle *Command, } /** - * Updates the arguments of CommandDesc->hNewKernel - * @param[in] Device The device associated with the kernel being updated. + * Updates the arguments of a kernel command. + * @param[in] Command The command associated with the kernel node being updated. * @param[in] UpdateCommandDesc The update command description that contains the - * new kernel and its arguments. + * new arguments. * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -updateKernelArguments(ur_device_handle_t Device, +updateKernelArguments(kernel_command_handle *Command, const ur_exp_command_buffer_update_kernel_launch_desc_t *UpdateCommandDesc) { - ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel; + ur_kernel_handle_t Kernel = Command->Kernel; + ur_device_handle_t Device = Command->CommandBuffer->Device; // Update pointer arguments to the kernel uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs; @@ -1261,7 +1262,7 @@ updateKernelArguments(ur_device_handle_t Device, ur_result_t Result = UR_RESULT_SUCCESS; try { - NewKernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue); + Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue); } catch (ur_result_t Err) { Result = Err; return Result; @@ -1280,10 +1281,10 @@ updateKernelArguments(ur_device_handle_t Device, ur_result_t Result = UR_RESULT_SUCCESS; try { if (ArgValue == nullptr) { - NewKernel->setKernelArg(ArgIndex, 0, nullptr); + Kernel->setKernelArg(ArgIndex, 0, nullptr); } else { CUdeviceptr CuPtr = std::get(ArgValue->Mem).getPtr(Device); - NewKernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr); + Kernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr); } } catch (ur_result_t Err) { Result = Err; @@ -1303,7 +1304,7 @@ updateKernelArguments(ur_device_handle_t Device, ur_result_t Result = UR_RESULT_SUCCESS; try { - NewKernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); } catch (ur_result_t Err) { Result = Err; return Result; @@ -1364,9 +1365,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( auto KernelCommandHandle = static_cast(hCommand); UR_CHECK_ERROR(validateCommandDesc(KernelCommandHandle, pUpdateKernelLaunch)); - UR_CHECK_ERROR( - updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch)); UR_CHECK_ERROR(updateCommand(KernelCommandHandle, pUpdateKernelLaunch)); + UR_CHECK_ERROR( + updateKernelArguments(KernelCommandHandle, pUpdateKernelLaunch)); // If no work-size is provided make sure we pass nullptr to setKernelParams so // it can guess the local work size. diff --git a/source/adapters/hip/command_buffer.cpp b/source/adapters/hip/command_buffer.cpp index 9ecb1a5477..afd15c1bd4 100644 --- a/source/adapters/hip/command_buffer.cpp +++ b/source/adapters/hip/command_buffer.cpp @@ -951,18 +951,19 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command, } /** - * Updates the arguments of CommandDesc->hNewKernel - * @param[in] Device The device associated with the kernel being updated. - * @param[in] UpdateCommandDesc The update command description that contains - * the new kernel and its arguments. + * Updates the arguments of a kernel command. + * @param[in] Command The command associated with the kernel node being updated. + * @param[in] UpdateCommandDesc The update command description that contains the + * new arguments. * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -updateKernelArguments(ur_device_handle_t Device, +updateKernelArguments(ur_exp_command_buffer_command_handle_t Command, const ur_exp_command_buffer_update_kernel_launch_desc_t *UpdateCommandDesc) { - ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel; + ur_kernel_handle_t Kernel = Command->Kernel; + ur_device_handle_t Device = Command->CommandBuffer->Device; // Update pointer arguments to the kernel uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs; @@ -974,7 +975,7 @@ updateKernelArguments(ur_device_handle_t Device, const void *ArgValue = PointerArgDesc.pNewPointerArg; try { - NewKernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue); + Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue); } catch (ur_result_t Err) { return Err; } @@ -991,10 +992,10 @@ updateKernelArguments(ur_device_handle_t Device, try { if (ArgValue == nullptr) { - NewKernel->setKernelArg(ArgIndex, 0, nullptr); + Kernel->setKernelArg(ArgIndex, 0, nullptr); } else { void *HIPPtr = std::get(ArgValue->Mem).getVoid(Device); - NewKernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr); + Kernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr); } } catch (ur_result_t Err) { return Err; @@ -1012,7 +1013,7 @@ updateKernelArguments(ur_device_handle_t Device, const void *ArgValue = ValueArgDesc.pNewValueArg; try { - NewKernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); } catch (ur_result_t Err) { return Err; } @@ -1067,9 +1068,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer; UR_CHECK_ERROR(validateCommandDesc(hCommand, pUpdateKernelLaunch)); - UR_CHECK_ERROR( - updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch)); UR_CHECK_ERROR(updateCommand(hCommand, pUpdateKernelLaunch)); + UR_CHECK_ERROR(updateKernelArguments(hCommand, pUpdateKernelLaunch)); // If no worksize is provided make sure we pass nullptr to setKernelParams // so it can guess the local work size. diff --git a/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp b/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp index c74af937f6..9fb408fb42 100644 --- a/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp +++ b/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp @@ -473,3 +473,30 @@ TEST_P(urCommandBufferValidUpdateParametersTest, UpdateOnlyLocalWorkSize) { ASSERT_NO_FATAL_FAILURE(SaxpyKernel->validate()); } + +// Tests that passing nullptr to hNewKernel works. +TEST_P(urCommandBufferValidUpdateParametersTest, SuccessNullptrHandle) { + + std::vector KernelAlternatives = { + FillUSM2DKernel->Kernel}; + + uur::raii::CommandBufferCommand CommandHandle; + ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp( + updatable_cmd_buf_handle, SaxpyKernel->Kernel, SaxpyKernel->NDimensions, + &(SaxpyKernel->GlobalOffset), &(SaxpyKernel->GlobalSize), + &(SaxpyKernel->LocalSize), KernelAlternatives.size(), + KernelAlternatives.data(), 0, nullptr, 0, nullptr, nullptr, nullptr, + CommandHandle.ptr())); + ASSERT_NE(CommandHandle, nullptr); + + ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); + + SaxpyKernel->UpdateDesc.hNewKernel = nullptr; + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + CommandHandle, &SaxpyKernel->UpdateDesc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ASSERT_NO_FATAL_FAILURE(SaxpyKernel->validate()); +}