Skip to content

Commit 8868230

Browse files
committed
[L0] Add support for multi-device kernel compilation
Signed-off-by: Spruit, Neil R <[email protected]>
1 parent 92f44da commit 8868230

File tree

5 files changed

+195
-113
lines changed

5 files changed

+195
-113
lines changed

source/adapters/level_zero/kernel.cpp

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141
*OutEvent ///< [in,out][optional] return an event object that identifies
4242
///< this particular kernel execution instance.
4343
) {
44+
auto ZeDevice = Queue->Device->ZeDevice;
45+
46+
ze_kernel_handle_t ZeKernel{};
47+
if (Kernel->ZeKernelMap.empty()) {
48+
ZeKernel = Kernel->ZeKernel;
49+
} else {
50+
auto It = Kernel->ZeKernelMap.find(ZeDevice);
51+
ZeKernel = It->second;
52+
}
4453
// Lock automatically releases when this goes out of scope.
4554
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
4655
Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
@@ -51,7 +60,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
5160
}
5261

5362
ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
54-
(Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
63+
(ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
5564
GlobalWorkOffset[2]));
5665
}
5766

@@ -65,7 +74,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6574
Queue->Device));
6675
}
6776
ZE2UR_CALL(zeKernelSetArgumentValue,
68-
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
77+
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
6978
}
7079
Kernel->PendingArguments.clear();
7180

@@ -99,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
99108
}
100109
if (SuggestGroupSize) {
101110
ZE2UR_CALL(zeKernelSuggestGroupSize,
102-
(Kernel->ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
111+
(ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
103112
GlobalWorkSize[2], &WG[0], &WG[1], &WG[2]));
104113
} else {
105114
for (int I : {0, 1, 2}) {
@@ -175,7 +184,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
175184
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
176185
}
177186

178-
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
187+
ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
179188

180189
bool UseCopyEngine = false;
181190
_ur_ze_event_list_t TmpWaitList;
@@ -227,18 +236,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
227236
Queue->CaptureIndirectAccesses();
228237
// Add the command to the command list, which implies submission.
229238
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
230-
(CommandList->first, Kernel->ZeKernel, &ZeThreadGroupDimensions,
231-
ZeEvent, (*Event)->WaitList.Length,
232-
(*Event)->WaitList.ZeEventList));
239+
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
240+
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
233241
} else {
234242
// Add the command to the command list for later submission.
235243
// No lock is needed here, unlike the immediate commandlist case above,
236244
// because the kernels are not actually submitted yet. Kernels will be
237245
// submitted only when the comamndlist is closed. Then, a lock is held.
238246
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
239-
(CommandList->first, Kernel->ZeKernel, &ZeThreadGroupDimensions,
240-
ZeEvent, (*Event)->WaitList.Length,
241-
(*Event)->WaitList.ZeEventList));
247+
(CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
248+
(*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
242249
}
243250

244251
urPrint("calling zeCommandListAppendLaunchKernel() with"
@@ -363,23 +370,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
363370
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
364371
}
365372

366-
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
367-
ZeKernelDesc.flags = 0;
368-
ZeKernelDesc.pKernelName = KernelName;
369-
370-
ze_kernel_handle_t ZeKernel;
371-
ZE2UR_CALL(zeKernelCreate, (Program->ZeModule, &ZeKernelDesc, &ZeKernel));
372-
373373
try {
374-
ur_kernel_handle_t_ *UrKernel =
375-
new ur_kernel_handle_t_(ZeKernel, true, Program);
374+
ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program);
376375
*RetKernel = reinterpret_cast<ur_kernel_handle_t>(UrKernel);
377376
} catch (const std::bad_alloc &) {
378377
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
379378
} catch (...) {
380379
return UR_RESULT_ERROR_UNKNOWN;
381380
}
382381

382+
for (auto It : Program->ZeModuleMap) {
383+
auto ZeModule = It.second;
384+
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
385+
ZeKernelDesc.flags = 0;
386+
ZeKernelDesc.pKernelName = KernelName;
387+
388+
ze_kernel_handle_t ZeKernel;
389+
ZE2UR_CALL(zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
390+
391+
auto ZeDevice = It.first;
392+
393+
// Store the kernel in the ZeKernelMap so the correct
394+
// kernel can be retrieved later for a specific device
395+
// where a queue is being submitted.
396+
(*RetKernel)->ZeKernelMap[ZeDevice] = ZeKernel;
397+
(*RetKernel)->ZeKernels.push_back(ZeKernel);
398+
399+
// If the device used to create the module's kernel is a root-device
400+
// then store the kernel also using the sub-devices, since application
401+
// could submit the root-device's kernel to a sub-device's queue.
402+
uint32_t SubDevicesCount = 0;
403+
zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, nullptr);
404+
std::vector<ze_device_handle_t> ZeSubDevices(SubDevicesCount);
405+
zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, ZeSubDevices.data());
406+
for (auto ZeSubDevice : ZeSubDevices) {
407+
(*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel;
408+
}
409+
}
410+
411+
(*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second;
412+
383413
UR_CALL((*RetKernel)->initialize());
384414

385415
return UR_RESULT_SUCCESS;
@@ -396,6 +426,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
396426
) {
397427
std::ignore = Properties;
398428

429+
UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
430+
399431
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
400432
// in which case a NULL value will be used as the value for the argument
401433
// declared as a pointer to global or constant memory in the kernel"
@@ -409,8 +441,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
409441
}
410442

411443
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
412-
ZE2UR_CALL(zeKernelSetArgumentValue,
413-
(Kernel->ZeKernel, ArgIndex, ArgSize, PArgValue));
444+
for (auto It : Kernel->ZeKernelMap) {
445+
auto ZeKernel = It.second;
446+
ZE2UR_CALL(zeKernelSetArgumentValue,
447+
(ZeKernel, ArgIndex, ArgSize, PArgValue));
448+
}
414449

415450
return UR_RESULT_SUCCESS;
416451
}
@@ -596,11 +631,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
596631

597632
auto KernelProgram = Kernel->Program;
598633
if (Kernel->OwnNativeHandle) {
599-
auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (Kernel->ZeKernel));
600-
// Gracefully handle the case that L0 was already unloaded.
601-
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
602-
return ze2urResult(ZeResult);
634+
for (auto &ZeKernel : Kernel->ZeKernels) {
635+
auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (ZeKernel));
636+
// Gracefully handle the case that L0 was already unloaded.
637+
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
638+
return ze2urResult(ZeResult);
639+
}
603640
}
641+
Kernel->ZeKernelMap.clear();
604642
if (IndirectAccessTrackingEnabled) {
605643
UR_CALL(urContextRelease(KernelProgram->Context));
606644
}
@@ -639,6 +677,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
639677
std::ignore = PropSize;
640678
std::ignore = Properties;
641679

680+
auto ZeKernel = Kernel->ZeKernel;
642681
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
643682
if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS &&
644683
*(static_cast<const ur_bool_t *>(PropValue)) == true) {
@@ -649,7 +688,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
649688
ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST |
650689
ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE |
651690
ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
652-
ZE2UR_CALL(zeKernelSetIndirectAccess, (Kernel->ZeKernel, IndirectFlags));
691+
ZE2UR_CALL(zeKernelSetIndirectAccess, (ZeKernel, IndirectFlags));
653692
} else if (PropName == UR_KERNEL_EXEC_INFO_CACHE_CONFIG) {
654693
ze_cache_config_flag_t ZeCacheConfig{};
655694
auto CacheConfig =
@@ -663,7 +702,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo(
663702
else
664703
// Unexpected cache configuration value.
665704
return UR_RESULT_ERROR_INVALID_VALUE;
666-
ZE2UR_CALL(zeKernelSetCacheConfig, (Kernel->ZeKernel, ZeCacheConfig););
705+
ZE2UR_CALL(zeKernelSetCacheConfig, (ZeKernel, ZeCacheConfig););
667706
} else {
668707
urPrint("urKernelSetExecInfo: unsupported ParamName\n");
669708
return UR_RESULT_ERROR_INVALID_VALUE;

source/adapters/level_zero/kernel.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414
#include <unordered_set>
1515

1616
struct ur_kernel_handle_t_ : _ur_object {
17-
ur_kernel_handle_t_(ze_kernel_handle_t Kernel, bool OwnZeHandle,
18-
ur_program_handle_t Program)
19-
: Context{nullptr}, Program{Program}, ZeKernel{Kernel},
20-
SubmissionsCount{0}, MemAllocs{} {
17+
ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program)
18+
: Program{Program}, SubmissionsCount{0}, MemAllocs{} {
2119
OwnNativeHandle = OwnZeHandle;
2220
}
2321

@@ -37,6 +35,15 @@ struct ur_kernel_handle_t_ : _ur_object {
3735
// Level Zero function handle.
3836
ze_kernel_handle_t ZeKernel;
3937

38+
// Map of L0 kernels created for all the devices for which a UR Program
39+
// has been built. It may contain duplicated kernel entries for a root
40+
// device and its sub-devices.
41+
std::unordered_map<ze_device_handle_t, ze_kernel_handle_t> ZeKernelMap;
42+
43+
// Vector of L0 kernels. Each entry is unique, so this is used for
44+
// destroying the kernels instead of ZeKernelMap
45+
std::vector<ze_kernel_handle_t> ZeKernels;
46+
4047
// Counter to track the number of submissions of the kernel.
4148
// When this value is zero, it means that kernel is not submitted for an
4249
// execution - at this time we can release memory allocations referenced by

0 commit comments

Comments
 (0)