Skip to content

Commit 7c58060

Browse files
authored
Merge pull request #1135 from nrspruit/multi_device_kernel_compilation_main
[L0] Add support for multi-device kernel compilation
2 parents 4f80080 + 1b2cd5b commit 7c58060

File tree

5 files changed

+228
-149
lines changed

5 files changed

+228
-149
lines changed

source/adapters/level_zero/kernel.cpp

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141
*OutEvent ///< [in,out][optional] return an event object that identifies
4242
///< this particular kernel execution instance.
4343
) {
44+
auto ZeDevice = Queue->Device->ZeDevice;
45+
46+
ze_kernel_handle_t ZeKernel{};
47+
if (Kernel->ZeKernelMap.empty()) {
48+
ZeKernel = Kernel->ZeKernel;
49+
} else {
50+
auto It = Kernel->ZeKernelMap.find(ZeDevice);
51+
ZeKernel = It->second;
52+
}
4453
// Lock automatically releases when this goes out of scope.
4554
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
4655
Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
5160
}
5261

5362
ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
54-
(Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
63+
(ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
5564
GlobalWorkOffset[2]));
5665
}
5766

@@ -65,18 +74,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6574
Queue->Device));
6675
}
6776
ZE2UR_CALL(zeKernelSetArgumentValue,
68-
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
77+
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
6978
}
7079
Kernel->PendingArguments.clear();
7180

7281
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
7382
uint32_t WG[3]{};
7483

7584
// global_work_size of unused dimensions must be set to 1
76-
UR_ASSERT(WorkDim == 3 || GlobalWorkSize[2] == 1,
77-
UR_RESULT_ERROR_INVALID_VALUE);
78-
UR_ASSERT(WorkDim >= 2 || GlobalWorkSize[1] == 1,
79-
UR_RESULT_ERROR_INVALID_VALUE);
85+
if (WorkDim >= 2) {
86+
UR_ASSERT(WorkDim >= 2 || GlobalWorkSize[1] == 1,
87+
UR_RESULT_ERROR_INVALID_VALUE);
88+
if (WorkDim == 3) {
89+
UR_ASSERT(WorkDim == 3 || GlobalWorkSize[2] == 1,
90+
UR_RESULT_ERROR_INVALID_VALUE);
91+
}
92+
}
8093
if (LocalWorkSize) {
8194
// L0
8295
UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits<uint32_t>::max)(),
@@ -99,7 +112,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99112
}
100113
if (SuggestGroupSize) {
101114
ZE2UR_CALL(zeKernelSuggestGroupSize,
102-
(Kernel->ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
115+
(ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
103116
GlobalWorkSize[2], &WG[0], &WG[1], &WG[2]));
104117
} else {
105118
for (int I : {0, 1, 2}) {
@@ -175,7 +188,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175188
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176189
}
177190

178-
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
191+
ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
179192

180193
bool UseCopyEngine = false;
181194
_ur_ze_event_list_t TmpWaitList;
@@ -227,18 +240,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227240
Queue->CaptureIndirectAccesses();
228241
// Add the command to the command list, which implies submission.
229242
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
230-
(CommandList->first, Kernel->ZeKernel, &ZeThreadGroupDimensions,
231-
ZeEvent, (*Event)->WaitList.Length,
232-
(*Event)->WaitList.ZeEventList));
243+
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
244+
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
233245
} else {
234246
// Add the command to the command list for later submission.
235247
// No lock is needed here, unlike the immediate commandlist case above,
236248
// because the kernels are not actually submitted yet. Kernels will be
237249
// submitted only when the comamndlist is closed. Then, a lock is held.
238250
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
239-
(CommandList->first, Kernel->ZeKernel, &ZeThreadGroupDimensions,
240-
ZeEvent, (*Event)->WaitList.Length,
241-
(*Event)->WaitList.ZeEventList));
251+
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
252+
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
242253
}
243254

244255
urPrint("calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +374,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363374
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364375
}
365376

366-
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
367-
ZeKernelDesc.flags = 0;
368-
ZeKernelDesc.pKernelName = KernelName;
369-
370-
ze_kernel_handle_t ZeKernel;
371-
ZE2UR_CALL(zeKernelCreate, (Program->ZeModule, &ZeKernelDesc, &ZeKernel));
372-
373377
try {
374-
ur_kernel_handle_t_ *UrKernel =
375-
new ur_kernel_handle_t_(ZeKernel, true, Program);
378+
ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program);
376379
*RetKernel = reinterpret_cast<ur_kernel_handle_t>(UrKernel);
377380
} catch (const std::bad_alloc &) {
378381
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379382
} catch (...) {
380383
return UR_RESULT_ERROR_UNKNOWN;
381384
}
382385

386+
for (auto It : Program->ZeModuleMap) {
387+
auto ZeModule = It.second;
388+
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
389+
ZeKernelDesc.flags = 0;
390+
ZeKernelDesc.pKernelName = KernelName;
391+
392+
ze_kernel_handle_t ZeKernel;
393+
ZE2UR_CALL(zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
394+
395+
auto ZeDevice = It.first;
396+
397+
// Store the kernel in the ZeKernelMap so the correct
398+
// kernel can be retrieved later for a specific device
399+
// where a queue is being submitted.
400+
(*RetKernel)->ZeKernelMap[ZeDevice] = ZeKernel;
401+
(*RetKernel)->ZeKernels.push_back(ZeKernel);
402+
403+
// If the device used to create the module's kernel is a root-device
404+
// then store the kernel also using the sub-devices, since application
405+
// could submit the root-device's kernel to a sub-device's queue.
406+
uint32_t SubDevicesCount = 0;
407+
zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, nullptr);
408+
std::vector<ze_device_handle_t> ZeSubDevices(SubDevicesCount);
409+
zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, ZeSubDevices.data());
410+
for (auto ZeSubDevice : ZeSubDevices) {
411+
(*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel;
412+
}
413+
}
414+
415+
(*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second;
416+
383417
UR_CALL((*RetKernel)->initialize());
384418

385419
return UR_RESULT_SUCCESS;
@@ -396,6 +430,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
396430
) {
397431
std::ignore = Properties;
398432

433+
UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
434+
399435
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
400436
// in which case a NULL value will be used as the value for the argument
401437
// declared as a pointer to global or constant memory in the kernel"
@@ -409,8 +445,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409445
}
410446

411447
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
412-
ZE2UR_CALL(zeKernelSetArgumentValue,
413-
(Kernel->ZeKernel, ArgIndex, ArgSize, PArgValue));
448+
for (auto It : Kernel->ZeKernelMap) {
449+
auto ZeKernel = It.second;
450+
ZE2UR_CALL(zeKernelSetArgumentValue,
451+
(ZeKernel, ArgIndex, ArgSize, PArgValue));
452+
}
414453

415454
return UR_RESULT_SUCCESS;
416455
}
@@ -596,16 +635,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596635

597636
auto KernelProgram = Kernel->Program;
598637
if (Kernel->OwnNativeHandle) {
599-
auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (Kernel->ZeKernel));
600-
// Gracefully handle the case that L0 was already unloaded.
601-
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602-
return ze2urResult(ZeResult);
638+
for (auto &ZeKernel : Kernel->ZeKernels) {
639+
auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (ZeKernel));
640+
// Gracefully handle the case that L0 was already unloaded.
641+
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
642+
return ze2urResult(ZeResult);
643+
}
603644
}
645+
Kernel->ZeKernelMap.clear();
604646
if (IndirectAccessTrackingEnabled) {
605647
UR_CALL(urContextRelease(KernelProgram->Context));
606648
}
607-
// do a release on the program this kernel was part of
608-
UR_CALL(urProgramRelease(KernelProgram));
649+
// do a release on the program this kernel was part of without delete of the
650+
// program handle
651+
KernelProgram->ur_release_program_resources(false);
652+
609653
delete Kernel;
610654

611655
return UR_RESULT_SUCCESS;
@@ -639,6 +683,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639683
std::ignore = PropSize;
640684
std::ignore = Properties;
641685

686+
auto ZeKernel = Kernel->ZeKernel;
642687
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
643688
if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644689
*(static_cast<const ur_bool_t *>(PropValue)) == true) {
@@ -649,7 +694,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649694
ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650695
ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651696
ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652-
ZE2UR_CALL(zeKernelSetIndirectAccess, (Kernel->ZeKernel, IndirectFlags));
697+
ZE2UR_CALL(zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653698
} else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654699
ze_cache_config_flag_t ZeCacheConfig{};
655700
auto CacheConfig =
@@ -663,7 +708,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663708
else
664709
// Unexpected cache configuration value.
665710
return UR_RESULT_ERROR_INVALID_VALUE;
666-
ZE2UR_CALL(zeKernelSetCacheConfig, (Kernel->ZeKernel, ZeCacheConfig););
711+
ZE2UR_CALL(zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667712
} else {
668713
urPrint("urKernelSetExecInfo: unsupported ParamName\n");
669714
return UR_RESULT_ERROR_INVALID_VALUE;

source/adapters/level_zero/kernel.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414
#include <unordered_set>
1515

1616
struct ur_kernel_handle_t_ : _ur_object {
17-
ur_kernel_handle_t_(ze_kernel_handle_t Kernel, bool OwnZeHandle,
18-
ur_program_handle_t Program)
19-
: Context{nullptr}, Program{Program}, ZeKernel{Kernel},
20-
SubmissionsCount{0}, MemAllocs{} {
17+
ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program)
18+
: Program{Program}, SubmissionsCount{0}, MemAllocs{} {
2119
OwnNativeHandle = OwnZeHandle;
2220
}
2321

@@ -37,6 +35,15 @@ struct ur_kernel_handle_t_ : _ur_object {
3735
// Level Zero function handle.
3836
ze_kernel_handle_t ZeKernel;
3937

38+
// Map of L0 kernels created for all the devices for which a UR Program
39+
// has been built. It may contain duplicated kernel entries for a root
40+
// device and its sub-devices.
41+
std::unordered_map<ze_device_handle_t, ze_kernel_handle_t> ZeKernelMap;
42+
43+
// Vector of L0 kernels. Each entry is unique, so this is used for
44+
// destroying the kernels instead of ZeKernelMap
45+
std::vector<ze_kernel_handle_t> ZeKernels;
46+
4047
// Counter to track the number of submissions of the kernel.
4148
// When this value is zero, it means that kernel is not submitted for an
4249
// execution - at this time we can release memory allocations referenced by

0 commit comments

Comments
 (0)