Skip to content

Commit 017d8f3

Browse files
committed
Fix binary update implementation to allow nullptr commmand handles
1 parent d9d24ec commit 017d8f3

File tree

3 files changed

+49
-23
lines changed

3 files changed

+49
-23
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -896,18 +896,19 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command,
896896
}
897897

898898
/**
899-
* Updates the arguments of CommandDesc->hNewKernel
900-
* @param[in] Device The device associated with the kernel being updated.
899+
* Updates the arguments of a kernel command.
900+
* @param[in] Command The command associated with the kernel node being updated.
901901
* @param[in] UpdateCommandDesc The update command description that contains the
902-
* new kernel and its arguments.
902+
* new arguments.
903903
* @return UR_RESULT_SUCCESS or an error code on failure
904904
*/
905905
ur_result_t
906-
updateKernelArguments(ur_device_handle_t Device,
906+
updateKernelArguments(ur_exp_command_buffer_command_handle_t Command,
907907
const ur_exp_command_buffer_update_kernel_launch_desc_t
908908
*UpdateCommandDesc) {
909909

910-
ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel;
910+
ur_kernel_handle_t Kernel = Command->Kernel;
911+
ur_device_handle_t Device = Command->CommandBuffer->Device;
911912

912913
// Update pointer arguments to the kernel
913914
uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs;
@@ -920,7 +921,7 @@ updateKernelArguments(ur_device_handle_t Device,
920921

921922
ur_result_t Result = UR_RESULT_SUCCESS;
922923
try {
923-
NewKernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
924+
Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
924925
} catch (ur_result_t Err) {
925926
Result = Err;
926927
return Result;
@@ -939,10 +940,10 @@ updateKernelArguments(ur_device_handle_t Device,
939940
ur_result_t Result = UR_RESULT_SUCCESS;
940941
try {
941942
if (ArgValue == nullptr) {
942-
NewKernel->setKernelArg(ArgIndex, 0, nullptr);
943+
Kernel->setKernelArg(ArgIndex, 0, nullptr);
943944
} else {
944945
CUdeviceptr CuPtr = std::get<BufferMem>(ArgValue->Mem).getPtr(Device);
945-
NewKernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
946+
Kernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
946947
}
947948
} catch (ur_result_t Err) {
948949
Result = Err;
@@ -962,7 +963,7 @@ updateKernelArguments(ur_device_handle_t Device,
962963

963964
ur_result_t Result = UR_RESULT_SUCCESS;
964965
try {
965-
NewKernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
966+
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
966967
} catch (ur_result_t Err) {
967968
Result = Err;
968969
return Result;
@@ -1018,9 +1019,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
10181019
ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer;
10191020

10201021
UR_CHECK_ERROR(validateCommandDesc(hCommand, pUpdateKernelLaunch));
1021-
UR_CHECK_ERROR(
1022-
updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch));
10231022
UR_CHECK_ERROR(updateCommand(hCommand, pUpdateKernelLaunch));
1023+
UR_CHECK_ERROR(updateKernelArguments(hCommand, pUpdateKernelLaunch));
10241024

10251025
// If no work-size is provided make sure we pass nullptr to setKernelParams so
10261026
// it can guess the local work size.

source/adapters/hip/command_buffer.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -882,18 +882,19 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command,
882882
}
883883

884884
/**
885-
* Updates the arguments of CommandDesc->hNewKernel
886-
* @param[in] Device The device associated with the kernel being updated.
887-
* @param[in] UpdateCommandDesc The update command description that contains
888-
* the new kernel and its arguments.
885+
* Updates the arguments of a kernel command.
886+
* @param[in] Command The command associated with the kernel node being updated.
887+
* @param[in] UpdateCommandDesc The update command description that contains the
888+
* new arguments.
889889
* @return UR_RESULT_SUCCESS or an error code on failure
890890
*/
891891
ur_result_t
892-
updateKernelArguments(ur_device_handle_t Device,
892+
updateKernelArguments(ur_exp_command_buffer_command_handle_t Command,
893893
const ur_exp_command_buffer_update_kernel_launch_desc_t
894894
*UpdateCommandDesc) {
895895

896-
ur_kernel_handle_t NewKernel = UpdateCommandDesc->hNewKernel;
896+
ur_kernel_handle_t Kernel = Command->Kernel;
897+
ur_device_handle_t Device = Command->CommandBuffer->Device;
897898

898899
// Update pointer arguments to the kernel
899900
uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs;
@@ -905,7 +906,7 @@ updateKernelArguments(ur_device_handle_t Device,
905906
const void *ArgValue = PointerArgDesc.pNewPointerArg;
906907

907908
try {
908-
NewKernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
909+
Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
909910
} catch (ur_result_t Err) {
910911
return Err;
911912
}
@@ -922,10 +923,10 @@ updateKernelArguments(ur_device_handle_t Device,
922923

923924
try {
924925
if (ArgValue == nullptr) {
925-
NewKernel->setKernelArg(ArgIndex, 0, nullptr);
926+
Kernel->setKernelArg(ArgIndex, 0, nullptr);
926927
} else {
927928
void *HIPPtr = std::get<BufferMem>(ArgValue->Mem).getVoid(Device);
928-
NewKernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr);
929+
Kernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr);
929930
}
930931
} catch (ur_result_t Err) {
931932
return Err;
@@ -943,7 +944,7 @@ updateKernelArguments(ur_device_handle_t Device,
943944
const void *ArgValue = ValueArgDesc.pNewValueArg;
944945

945946
try {
946-
NewKernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
947+
Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
947948
} catch (ur_result_t Err) {
948949
return Err;
949950
}
@@ -998,9 +999,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
998999
ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer;
9991000

10001001
UR_CHECK_ERROR(validateCommandDesc(hCommand, pUpdateKernelLaunch));
1001-
UR_CHECK_ERROR(
1002-
updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch));
10031002
UR_CHECK_ERROR(updateCommand(hCommand, pUpdateKernelLaunch));
1003+
UR_CHECK_ERROR(updateKernelArguments(hCommand, pUpdateKernelLaunch));
10041004

10051005
// If no worksize is provided make sure we pass nullptr to setKernelParams
10061006
// so it can guess the local work size.

test/conformance/exp_command_buffer/update/kernel_handle_update.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,29 @@ TEST_P(urCommandBufferValidUpdateParametersTest, UpdateOnlyLocalWorkSize) {
467467

468468
ASSERT_NO_FATAL_FAILURE(SaxpyKernel->validate());
469469
}
470+
471+
// Tests that passing nullptr to hNewKernel works.
472+
TEST_P(urCommandBufferValidUpdateParametersTest, SuccessNullptrHandle) {
473+
474+
std::vector<ur_kernel_handle_t> KernelAlternatives = {
475+
FillUSM2DKernel->Kernel};
476+
477+
uur::raii::CommandBufferCommand CommandHandle;
478+
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
479+
updatable_cmd_buf_handle, SaxpyKernel->Kernel, SaxpyKernel->NDimensions,
480+
&(SaxpyKernel->GlobalOffset), &(SaxpyKernel->GlobalSize),
481+
&(SaxpyKernel->LocalSize), KernelAlternatives.size(),
482+
KernelAlternatives.data(), 0, nullptr, nullptr, CommandHandle.ptr()));
483+
ASSERT_NE(CommandHandle, nullptr);
484+
485+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
486+
487+
SaxpyKernel->UpdateDesc.hNewKernel = nullptr;
488+
ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(
489+
CommandHandle, &SaxpyKernel->UpdateDesc));
490+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
491+
nullptr, nullptr));
492+
ASSERT_SUCCESS(urQueueFinish(queue));
493+
494+
ASSERT_NO_FATAL_FAILURE(SaxpyKernel->validate());
495+
}

0 commit comments

Comments
 (0)