Commit a50acd0

[NFC][UR][L0] Eliminate code duplication in kernel launch logic (#19562)
This PR removes duplicated code between regular and cooperative kernel launches. The only actual difference between these paths is the use of either `zeCommandListAppendLaunchKernel` or `zeCommandListAppendLaunchCooperativeKernel`. This simplifies maintenance by consolidating identical code and reduces the risk of divergence.
1 parent f9b20bf commit a50acd0
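
The change boils down to a small function-pointer dispatch: select the Level Zero entry point once, up front, then run a single shared launch path. A minimal, self-contained C++ sketch of that pattern (the names below are illustrative stand-ins, not the adapter's real symbols):

#include <cstdint>

// Stand-ins for the two Level Zero append entry points, which share one
// signature; `int` stands in for ze_result_t.
using LaunchFuncT = int (*)(std::uint32_t GroupCount);
int AppendLaunchKernel(std::uint32_t) { return 0; }
int AppendLaunchCooperativeKernel(std::uint32_t) { return 0; }

int Enqueue(bool Cooperative, std::uint32_t GroupCount) {
  // Select the entry point once; everything after this line is shared, so
  // the regular and cooperative paths can no longer diverge.
  LaunchFuncT Launch =
      Cooperative ? &AppendLaunchCooperativeKernel : &AppendLaunchKernel;
  return Launch(GroupCount);
}

int main() { return Enqueue(false, 8) + Enqueue(true, 8); }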

File tree

1 file changed: +7 −262 lines

  • unified-runtime/source/adapters/level_zero
unified-runtime/source/adapters/level_zero/kernel.cpp

@@ -56,263 +56,6 @@ ur_result_t urKernelGetSuggestedLocalWorkSize(
   return UR_RESULT_SUCCESS;
 }
 
-inline ur_result_t EnqueueCooperativeKernelLaunchHelper(
-    /// [in] handle of the queue object
-    ur_queue_handle_t Queue,
-    /// [in] handle of the kernel object
-    ur_kernel_handle_t Kernel,
-    /// [in] number of dimensions, from 1 to 3, to specify the global and
-    /// work-group work-items
-    uint32_t WorkDim,
-    /// [in][optional] pointer to an array of workDim unsigned values that
-    /// specify the offset used to calculate the global ID of a work-item
-    const size_t *GlobalWorkOffset,
-    /// [in] pointer to an array of workDim unsigned values that specify the
-    /// number of global work-items in workDim that will execute the kernel
-    /// function
-    const size_t *GlobalWorkSize,
-    /// [in][optional] pointer to an array of workDim unsigned values that
-    /// specify the number of local work-items forming a work-group that
-    /// will execute the kernel function. If nullptr, the runtime
-    /// implementation will choose the work-group size.
-    const size_t *LocalWorkSize,
-    /// [in] size of the event wait list
-    uint32_t NumEventsInWaitList,
-    /// [in][optional][range(0, numEventsInWaitList)] pointer to a list of
-    /// events that must be complete before the kernel execution. If
-    /// nullptr, the numEventsInWaitList must be 0, indicating that there
-    /// are no wait events.
-    const ur_event_handle_t *EventWaitList,
-    /// [in,out][optional] return an event object that identifies this
-    /// particular kernel execution instance.
-    ur_event_handle_t *OutEvent) {
-  UR_ASSERT(WorkDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
-  UR_ASSERT(WorkDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
-
-  auto ZeDevice = Queue->Device->ZeDevice;
-
-  ze_kernel_handle_t ZeKernel{};
-  if (Kernel->ZeKernelMap.empty()) {
-    ZeKernel = Kernel->ZeKernel;
-  } else {
-    auto It = Kernel->ZeKernelMap.find(ZeDevice);
-    if (It == Kernel->ZeKernelMap.end()) {
-      /* kernel and queue don't match */
-      return UR_RESULT_ERROR_INVALID_QUEUE;
-    }
-    ZeKernel = It->second;
-  }
-  // Lock automatically releases when this goes out of scope.
-  std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
-      Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
-  if (GlobalWorkOffset != NULL) {
-    UR_CALL(setKernelGlobalOffset(Queue->Context, ZeKernel, WorkDim,
-                                  GlobalWorkOffset));
-  }
-
-  // If there are any pending arguments, set them now.
-  for (auto &Arg : Kernel->PendingArguments) {
-    // The ArgValue may be a NULL pointer in which case a NULL value is used for
-    // the kernel argument declared as a pointer to global or constant memory.
-    char **ZeHandlePtr = nullptr;
-    if (Arg.Value) {
-      UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
-                                        Queue->Device, EventWaitList,
-                                        NumEventsInWaitList));
-    }
-    ZE2UR_CALL(zeKernelSetArgumentValue,
-               (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
-  }
-  Kernel->PendingArguments.clear();
-
-  ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
-  uint32_t WG[3]{};
-
-  // New variable needed because GlobalWorkSize parameter might not be of size 3
-  size_t GlobalWorkSize3D[3]{1, 1, 1};
-  std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
-
-  if (LocalWorkSize) {
-    // L0
-    for (uint32_t I = 0; I < WorkDim; I++) {
-      UR_ASSERT(LocalWorkSize[I] < (std::numeric_limits<uint32_t>::max)(),
-                UR_RESULT_ERROR_INVALID_VALUE);
-      WG[I] = static_cast<uint32_t>(LocalWorkSize[I]);
-    }
-  } else {
-    // We can't call zeKernelSuggestGroupSize if the 64-bit GlobalWorkSize
-    // values do not fit into the 32-bit range the API currently supports.
-    bool SuggestGroupSize = true;
-    for (int I : {0, 1, 2}) {
-      if (GlobalWorkSize3D[I] > UINT32_MAX) {
-        SuggestGroupSize = false;
-      }
-    }
-    if (SuggestGroupSize) {
-      ZE2UR_CALL(zeKernelSuggestGroupSize,
-                 (ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
-                  GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
-    } else {
-      for (int I : {0, 1, 2}) {
-        // Try to find an I-dimension WG size that GlobalWorkSize[I] is
-        // evenly divisible by. Start with the max possible size in
-        // each dimension.
-        uint32_t GroupSize[] = {
-            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX,
-            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY,
-            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ};
-        GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
-        while (GlobalWorkSize3D[I] % GroupSize[I]) {
-          --GroupSize[I];
-        }
-
-        if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
-          UR_LOG(ERR,
-                 "urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
-                 "suitable for global work size > UINT32_MAX");
-          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-        }
-        WG[I] = GroupSize[I];
-      }
-      UR_LOG(DEBUG,
-             "urEnqueueCooperativeKernelLaunchExp: using computed WG "
-             "size = {{{}, {}, {}}}",
-             WG[0], WG[1], WG[2]);
-    }
-  }
-
-  // TODO: assert if sizes do not fit into 32-bit?
-
-  switch (WorkDim) {
-  case 3:
-    ZeThreadGroupDimensions.groupCountX =
-        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
-    ZeThreadGroupDimensions.groupCountY =
-        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
-    ZeThreadGroupDimensions.groupCountZ =
-        static_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
-    break;
-  case 2:
-    ZeThreadGroupDimensions.groupCountX =
-        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
-    ZeThreadGroupDimensions.groupCountY =
-        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
-    WG[2] = 1;
-    break;
-  case 1:
-    ZeThreadGroupDimensions.groupCountX =
-        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
-    WG[1] = WG[2] = 1;
-    break;
-
-  default:
-    UR_LOG(ERR, "urEnqueueCooperativeKernelLaunchExp: unsupported work_dim");
-    return UR_RESULT_ERROR_INVALID_VALUE;
-  }
-
-  // Error handling for non-uniform group size case
-  if (GlobalWorkSize3D[0] !=
-      size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
-    UR_LOG(ERR,
-           "urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
-           "range is not a multiple of the group size in the 1st dimension");
-    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-  }
-  if (GlobalWorkSize3D[1] !=
-      size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
-    UR_LOG(ERR,
-           "urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
-           "range is not a multiple of the group size in the 2nd dimension");
-    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-  }
-  if (GlobalWorkSize3D[2] !=
-      size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
-    UR_LOG(DEBUG,
-           "urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
-           "range is not a multiple of the group size in the 3rd dimension");
-    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-  }
-
-  ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
-
-  bool UseCopyEngine = false;
-  ur_ze_event_list_t TmpWaitList;
-  UR_CALL(TmpWaitList.createAndRetainUrZeEventList(
-      NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
-
-  // Get a new command list to be used on this call
-  ur_command_list_ptr_t CommandList{};
-  UR_CALL(Queue->Context->getAvailableCommandList(
-      Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
-      true /* AllowBatching */, nullptr /*ForcedCmdQueue*/));
-
-  ze_event_handle_t ZeEvent = nullptr;
-  ur_event_handle_t InternalEvent{};
-  bool IsInternal = OutEvent == nullptr;
-  ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
-
-  UR_CALL(createEventAndAssociateQueue(Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
-                                       CommandList, IsInternal, false));
-  UR_CALL(setSignalEvent(Queue, UseCopyEngine, &ZeEvent, Event,
-                         NumEventsInWaitList, EventWaitList,
-                         CommandList->second.ZeQueue));
-  (*Event)->WaitList = TmpWaitList;
-
-  // Save the kernel in the event, so that when the event is signalled
-  // the code can do a urKernelRelease on this kernel.
-  (*Event)->CommandData = (void *)Kernel;
-
-  // Increment the reference count of the Kernel and indicate that the Kernel
-  // is in use. Once the event has been signalled, the code in
-  // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
-  // reference count on the kernel, using the kernel saved in CommandData.
-  UR_CALL(ur::level_zero::urKernelRetain(Kernel));
-
-  // Add to list of kernels to be submitted
-  if (IndirectAccessTrackingEnabled)
-    Queue->KernelsToBeSubmitted.push_back(Kernel);
-
-  if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
-    // If using immediate commandlists then gathering of indirect
-    // references and appending to the queue (which means submission)
-    // must be done together.
-    std::unique_lock<ur_shared_mutex> ContextsLock(
-        Queue->Device->Platform->ContextsMutex, std::defer_lock);
-    // We are going to submit kernels for execution. If indirect access flag is
-    // set for a kernel then we need to make a snapshot of existing memory
-    // allocations in all contexts in the platform. We need to lock the mutex
-    // guarding the list of contexts in the platform to prevent creation of new
-    // memory allocations in any context before we submit the kernel for
-    // execution.
-    ContextsLock.lock();
-    Queue->CaptureIndirectAccesses();
-    // Add the command to the command list, which implies submission.
-    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
-               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
-                (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
-  } else {
-    // Add the command to the command list for later submission.
-    // No lock is needed here, unlike the immediate commandlist case above,
-    // because the kernels are not actually submitted yet. Kernels will be
-    // submitted only when the commandlist is closed. Then, a lock is held.
-    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
-               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
-                (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
-  }
-
-  UR_LOG(DEBUG,
-         "calling zeCommandListAppendLaunchCooperativeKernel() with ZeEvent {}",
-         ur_cast<std::uintptr_t>(ZeEvent));
-  printZeEventList((*Event)->WaitList);
-
-  // Execute command list asynchronously, as the event will be used
-  // to track down its completion.
-  UR_CALL(Queue->executeCommandList(CommandList, false /*IsBlocking*/,
-                                    true /*OKToBatchCommand*/));
-
-  return UR_RESULT_SUCCESS;
-}
-
 ur_result_t urEnqueueKernelLaunch(
     /// [in] handle of the queue object
     ur_queue_handle_t Queue,
@@ -348,14 +91,16 @@ ur_result_t urEnqueueKernelLaunch(
     /// [in,out][optional] return an event object that identifies this
     /// particular kernel execution instance.
     ur_event_handle_t *OutEvent) {
+  using ZeKernelLaunchFuncT = ze_result_t (*)(
+      ze_command_list_handle_t, ze_kernel_handle_t, const ze_group_count_t *,
+      ze_event_handle_t, uint32_t, ze_event_handle_t *);
+  ZeKernelLaunchFuncT ZeKernelLaunchFunc = &zeCommandListAppendLaunchKernel;
   for (uint32_t PropIndex = 0; PropIndex < NumPropsInLaunchPropList;
        PropIndex++) {
     if (LaunchPropList[PropIndex].id ==
             UR_KERNEL_LAUNCH_PROPERTY_ID_COOPERATIVE &&
         LaunchPropList[PropIndex].value.cooperative) {
-      return EnqueueCooperativeKernelLaunchHelper(
-          Queue, Kernel, WorkDim, GlobalWorkOffset, GlobalWorkSize,
-          LocalWorkSize, NumEventsInWaitList, EventWaitList, OutEvent);
+      ZeKernelLaunchFunc = &zeCommandListAppendLaunchCooperativeKernel;
     }
     if (LaunchPropList[PropIndex].id != UR_KERNEL_LAUNCH_PROPERTY_ID_IGNORE &&
         LaunchPropList[PropIndex].id !=
@@ -454,15 +199,15 @@ ur_result_t urEnqueueKernelLaunch(
     ContextsLock.lock();
     Queue->CaptureIndirectAccesses();
     // Add the command to the command list, which implies submission.
-    ZE2UR_CALL(zeCommandListAppendLaunchKernel,
+    ZE2UR_CALL(ZeKernelLaunchFunc,
                (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
                 (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
   } else {
     // Add the command to the command list for later submission.
     // No lock is needed here, unlike the immediate commandlist case above,
     // because the kernels are not actually submitted yet. Kernels will be
     // submitted only when the commandlist is closed. Then, a lock is held.
-    ZE2UR_CALL(zeCommandListAppendLaunchKernel,
+    ZE2UR_CALL(ZeKernelLaunchFunc,
                (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
                 (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
   }
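
The one-line substitution in these call sites works because ZE2UR_CALL expands its first argument into an ordinary call expression, and C++ invokes a function pointer with the same syntax as a named function. A minimal sketch of a macro with that shape (an assumed form for illustration, not the adapter's actual definition):

#include <cstdio>

using ze_result_t = int; // hypothetical stand-in for the Level Zero type
constexpr ze_result_t ZE_RESULT_SUCCESS = 0;
ze_result_t AppendLaunch(int Arg) { return Arg == 0 ? ZE_RESULT_SUCCESS : 1; }

// Expands to `Call Args`, so Call may be a function name or a pointer.
#define ZE2UR_CALL(Call, Args)                                                 \
  do {                                                                         \
    if (ze_result_t Result = Call Args; Result != ZE_RESULT_SUCCESS)           \
      std::printf("level-zero call failed: %d\n", Result);                     \
  } while (0)

int main() {
  ze_result_t (*LaunchFunc)(int) = &AppendLaunch;
  ZE2UR_CALL(AppendLaunch, (0)); // named function, as before this commit
  ZE2UR_CALL(LaunchFunc, (0));   // function pointer, as after this commit
}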
