Skip to content

Commit 0b0295c

Browse files
author
Ewan Crawford
committed
Fix L0 command-buffer consumption of multi-device kernels
UR program and kernel objects can be tied to multiple devices, a UR command-buffer object however is tied to a single device. When appending a kernel command to a command-buffer, select the correct single-device ze_kernel_handle_t object from the multi-device ur_kernel_handle_t object
1 parent 098deca commit 0b0295c

File tree

1 file changed

+48
-25
lines changed

1 file changed

+48
-25
lines changed

source/adapters/level_zero/command_buffer.cpp

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -894,28 +894,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
894894
/**
895895
* Sets the kernel arguments for a kernel command that will be appended to the
896896
* command buffer.
897-
* @param[in] CommandBuffer The CommandBuffer where the command will be
897+
* @param[in] Device The Device associated with the command-buffer where the
898+
* kernel command will be appended.
899+
* @param[in,out] Arguments stored in the ur_kernel_handle_t object to be set
900+
* on the /p ZeKernel object.
901+
* @param[in] ZeKernel The handle to the Level-Zero kernel that will be
898902
* appended.
899-
* @param[in] Kernel The handle to the kernel that will be appended.
900903
* @return UR_RESULT_SUCCESS or an error code on failure
901904
*/
902-
ur_result_t
903-
setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
904-
ur_kernel_handle_t Kernel) {
905-
905+
ur_result_t setKernelPendingArguments(
906+
ur_device_handle_t Device,
907+
std::vector<ur_kernel_handle_t_::ArgumentInfo> &PendingArguments,
908+
ze_kernel_handle_t ZeKernel) {
906909
// If there are any pending arguments set them now.
907-
for (auto &Arg : Kernel->PendingArguments) {
910+
for (auto &Arg : PendingArguments) {
908911
// The ArgValue may be a NULL pointer in which case a NULL value is used for
909912
// the kernel argument declared as a pointer to global or constant memory.
910913
char **ZeHandlePtr = nullptr;
911914
if (Arg.Value) {
912-
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
913-
CommandBuffer->Device, nullptr, 0u));
915+
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
916+
nullptr, 0u));
914917
}
915918
ZE2UR_CALL(zeKernelSetArgumentValue,
916-
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
919+
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
917920
}
918-
Kernel->PendingArguments.clear();
921+
PendingArguments.clear();
919922

920923
return UR_RESULT_SUCCESS;
921924
}
@@ -951,21 +954,29 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
951954
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
952955

953956
auto Platform = CommandBuffer->Context->getPlatform();
957+
auto ZeDevice = CommandBuffer->Device->ZeDevice;
958+
954959
if (NumKernelAlternatives > 0) {
955960
ZeMutableCommandDesc.flags |=
956961
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
957962

958963
std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
959964
NumKernelAlternatives + 1, nullptr);
960965

966+
ze_kernel_handle_t ZeMainKernel{};
967+
UR_CALL(getZeKernel(ZeDevice, Kernel, &ZeMainKernel));
968+
961969
// Translate main kernel first
962970
ZE2UR_CALL(zelLoaderTranslateHandle,
963-
(ZEL_HANDLE_KERNEL, Kernel->ZeKernel,
971+
(ZEL_HANDLE_KERNEL, ZeMainKernel,
964972
(void **)&TranslatedKernelHandles[0]));
965973

966974
for (size_t i = 0; i < NumKernelAlternatives; i++) {
975+
ze_kernel_handle_t ZeAltKernel{};
976+
UR_CALL(getZeKernel(ZeDevice, KernelAlternatives[i], &ZeAltKernel));
977+
967978
ZE2UR_CALL(zelLoaderTranslateHandle,
968-
(ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel,
979+
(ZEL_HANDLE_KERNEL, ZeAltKernel,
969980
(void **)&TranslatedKernelHandles[i + 1]));
970981
}
971982

@@ -1022,23 +1033,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10221033
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
10231034
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
10241035

1036+
auto Device = CommandBuffer->Device;
1037+
ze_kernel_handle_t ZeKernel{};
1038+
UR_CALL(getZeKernel(Device->ZeDevice, Kernel, &ZeKernel));
1039+
10251040
if (GlobalWorkOffset != NULL) {
1026-
UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, Kernel->ZeKernel,
1027-
WorkDim, GlobalWorkOffset));
1041+
UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, ZeKernel, WorkDim,
1042+
GlobalWorkOffset));
10281043
}
10291044

10301045
// If there are any pending arguments set them now.
10311046
if (!Kernel->PendingArguments.empty()) {
1032-
UR_CALL(setKernelPendingArguments(CommandBuffer, Kernel));
1047+
UR_CALL(
1048+
setKernelPendingArguments(Device, Kernel->PendingArguments, ZeKernel));
10331049
}
10341050

10351051
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
10361052
uint32_t WG[3];
1037-
UR_CALL(calculateKernelWorkDimensions(Kernel->ZeKernel, CommandBuffer->Device,
1053+
UR_CALL(calculateKernelWorkDimensions(ZeKernel, Device,
10381054
ZeThreadGroupDimensions, WG, WorkDim,
10391055
GlobalWorkSize, LocalWorkSize));
10401056

1041-
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
1057+
ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
10421058

10431059
CommandBuffer->KernelsList.push_back(Kernel);
10441060
for (size_t i = 0; i < NumKernelAlternatives; i++) {
@@ -1063,7 +1079,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
10631079
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
10641080

10651081
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
1066-
(CommandBuffer->ZeComputeCommandList, Kernel->ZeKernel,
1082+
(CommandBuffer->ZeComputeCommandList, ZeKernel,
10671083
&ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size(),
10681084
getPointerFromVector(ZeEventList)));
10691085

@@ -1836,6 +1852,7 @@ ur_result_t updateKernelCommand(
18361852
const auto CommandBuffer = Command->CommandBuffer;
18371853
const void *NextDesc = nullptr;
18381854
auto Platform = CommandBuffer->Context->getPlatform();
1855+
auto ZeDevice = CommandBuffer->Device->ZeDevice;
18391856

18401857
uint32_t Dim = CommandDesc->newWorkDim;
18411858
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
@@ -1844,11 +1861,14 @@ ur_result_t updateKernelCommand(
18441861

18451862
// Kernel handle must be updated first for a given CommandId if required
18461863
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;
1864+
18471865
if (NewKernel && Command->Kernel != NewKernel) {
1866+
ze_kernel_handle_t ZeNewKernel{};
1867+
UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel));
1868+
18481869
ze_kernel_handle_t ZeKernelTranslated = nullptr;
1849-
ZE2UR_CALL(
1850-
zelLoaderTranslateHandle,
1851-
(ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated));
1870+
ZE2UR_CALL(zelLoaderTranslateHandle,
1871+
(ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
18521872

18531873
ZE2UR_CALL(Platform->ZeMutableCmdListExt
18541874
.zexCommandListUpdateMutableCommandKernelsExp,
@@ -1905,10 +1925,13 @@ ur_result_t updateKernelCommand(
19051925
// by the driver for the kernel.
19061926
bool UpdateWGSize = NewLocalWorkSize == nullptr;
19071927

1928+
ze_kernel_handle_t ZeKernel{};
1929+
UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel));
1930+
19081931
uint32_t WG[3];
1909-
UR_CALL(calculateKernelWorkDimensions(
1910-
Command->Kernel->ZeKernel, CommandBuffer->Device,
1911-
ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize));
1932+
UR_CALL(calculateKernelWorkDimensions(ZeKernel, CommandBuffer->Device,
1933+
ZeThreadGroupDimensions, WG, Dim,
1934+
NewGlobalWorkSize, NewLocalWorkSize));
19121935

19131936
auto MutableGroupCountDesc =
19141937
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();

0 commit comments

Comments
 (0)