
Commit 63c70a1

[NativeCPU] Simplify enqueue. (#19550)
We were creating excessive numbers of threads. Since we know how many threads we want, we can instead divide the number of workgroups by the number of threads and have each thread process that many workgroups. This implementation also means we no longer need to resize workgroups, which was not generally safe.
1 parent a70552b commit 63c70a1
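
To make the new scheduling concrete, here is a minimal standalone sketch of the partitioning scheme the diff below introduces: the linearized workgroup count is split into contiguous per-thread ranges, the remainder is spread one extra workgroup over the first few threads, and each range end is decomposed back into (g0, g1, g2). The workgroup shape, thread count and printf reporting are illustrative assumptions, not adapter code.

// Sketch only: mirrors the range-splitting logic of the new enqueue loop.
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical example values; the real code gets these from the ND-range
  // and the thread pool.
  const std::size_t numWG0 = 5, numWG1 = 3, numWG2 = 2;
  const std::size_t numParallelThreads = 4;

  const std::size_t numWG = numWG0 * numWG1 * numWG2;
  const std::size_t numWGPerThread = numWG / numParallelThreads;
  const std::size_t remainderWG = numWG - numWGPerThread * numParallelThreads;

  // Fourth element is the linearized workgroup index.
  std::array<std::size_t, 4> rangeStart = {0, 0, 0, 0};
  for (std::size_t t = 0; t < numParallelThreads; ++t) {
    auto rangeEnd = rangeStart;
    // The first remainderWG threads take one extra workgroup each.
    rangeEnd[3] += numWGPerThread + (t < remainderWG ? 1 : 0);
    if (rangeEnd[3] == rangeStart[3])
      break; // more threads than workgroups: nothing left to hand out
    // Decompose the linearized end index back into (g0, g1, g2).
    rangeEnd[0] = rangeEnd[3] % numWG0;
    rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
    rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);
    std::printf("thread %zu: linearized workgroups [%zu, %zu)\n", t,
                rangeStart[3], rangeEnd[3]);
    rangeStart = rangeEnd;
  }
  return 0;
}

With the example values (30 workgroups, 4 threads) this prints ranges of 8, 8, 7 and 7 workgroups, so every thread is scheduled exactly once instead of one task per workgroup.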

unified-runtime/source/adapters/native_cpu/enqueue.cpp

Lines changed: 57 additions & 128 deletions
@@ -52,17 +52,6 @@ struct NDRDescT {
 };
 } // namespace native_cpu
 
-#ifdef NATIVECPU_USE_OCK
-static native_cpu::state getResizedState(const native_cpu::NDRDescT &ndr,
-                                         size_t itemsPerThread) {
-  native_cpu::state resized_state(
-      ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], itemsPerThread,
-      ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0],
-      ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
-  return resized_state;
-}
-#endif
-
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -112,6 +101,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   // TODO: add proper error checking
   native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
                            pLocalWorkSize);
+  unsigned long long numWI;
+  auto umulll_overflow = [](unsigned long long a, unsigned long long b,
+                            unsigned long long *c) -> bool {
+#ifdef __GNUC__
+    return __builtin_umulll_overflow(a, b, c);
+#else
+    *c = a * b;
+    return a != 0 && b != *c / a;
+#endif
+  };
+  if (umulll_overflow(ndr.GlobalSize[0], ndr.GlobalSize[1], &numWI) ||
+      umulll_overflow(numWI, ndr.GlobalSize[2], &numWI) || numWI > SIZE_MAX) {
+    return UR_RESULT_ERROR_OUT_OF_RESOURCES;
+  }
+
   auto &tp = hQueue->getDevice()->tp;
   const size_t numParallelThreads = tp.num_threads();
   std::vector<std::future<void>> futures;
@@ -130,131 +134,56 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
   kernel->updateMemPool(numParallelThreads);
 
+  const size_t numWG = numWG0 * numWG1 * numWG2;
+  const size_t numWGPerThread = numWG / numParallelThreads;
+  const size_t remainderWG = numWG - numWGPerThread * numParallelThreads;
+  // The fourth value is the linearized value.
+  std::array<size_t, 4> rangeStart = {0, 0, 0, 0};
+  for (unsigned t = 0; t < numParallelThreads; ++t) {
+    auto rangeEnd = rangeStart;
+    rangeEnd[3] += numWGPerThread + (t < remainderWG);
+    if (rangeEnd[3] == rangeStart[3])
+      break;
+    rangeEnd[0] = rangeEnd[3] % numWG0;
+    rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
+    rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);
+    futures.emplace_back(
+        tp.schedule_task([state, &kernel = *kernel, rangeStart,
+                          rangeEnd = rangeEnd[3], numWG0, numWG1,
 #ifndef NATIVECPU_USE_OCK
-  for (unsigned g2 = 0; g2 < numWG2; g2++) {
-    for (unsigned g1 = 0; g1 < numWG1; g1++) {
-      for (unsigned g0 = 0; g0 < numWG0; g0++) {
-        for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
-          for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
-            for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
-              state.update(g0, g1, g2, local0, local1, local2);
-              kernel->_subhandler(kernel->getArgs(1, 0).data(), &state);
-            }
-          }
-        }
-      }
-    }
-  }
+                          localSize = ndr.LocalSize,
+#endif
+                          numParallelThreads](size_t threadId) mutable {
+          for (size_t g0 = rangeStart[0], g1 = rangeStart[1],
+                      g2 = rangeStart[2], g3 = rangeStart[3];
+               g3 < rangeEnd; ++g3) {
+#ifdef NATIVECPU_USE_OCK
+            state.update(g0, g1, g2);
+            kernel._subhandler(
+                kernel.getArgs(numParallelThreads, threadId).data(), &state);
 #else
-  bool isLocalSizeOne =
-      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
-      !kernel->hasLocalArgs()) {
-    // If the local size is one, we make the assumption that we are running a
-    // parallel_for over a sycl::range.
-    // Todo: we could add more compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called).
-
-    // Todo: this assumes that dim 0 is the best dimension over which we want to
-    // parallelize
-
-    // Since we also vectorize the kernel, and vectorization happens within the
-    // work group loop, it's better to have a large-ish local size. We can
-    // divide the global range by the number of threads, set that as the local
-    // size and peel everything else.
-
-    size_t new_num_work_groups_0 = numParallelThreads;
-    size_t itemsPerThread = ndr.GlobalSize[0] / numParallelThreads;
-
-    for (unsigned g2 = 0; g2 < numWG2; g2++) {
-      for (unsigned g1 = 0; g1 < numWG1; g1++) {
-        for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
-          futures.emplace_back(tp.schedule_task(
-              [ndr, itemsPerThread, &kernel = *kernel, g0, g1, g2](size_t) {
-                native_cpu::state resized_state =
-                    getResizedState(ndr, itemsPerThread);
-                resized_state.update(g0, g1, g2);
-                kernel._subhandler(kernel.getArgs().data(), &resized_state);
-              }));
-        }
-        // Peel the remaining work items. Since the local size is 1, we iterate
-        // over the work groups.
-        for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
-             g0++) {
-          state.update(g0, g1, g2);
-          kernel->_subhandler(kernel->getArgs().data(), &state);
-        }
-      }
-    }
-
-  } else {
-    // We are running a parallel_for over an nd_range
-
-    if (numWG1 * numWG2 >= numParallelThreads) {
-      // Dimensions 1 and 2 have enough work, split them across the threadpool
-      for (unsigned g2 = 0; g2 < numWG2; g2++) {
-        for (unsigned g1 = 0; g1 < numWG1; g1++) {
-          futures.emplace_back(
-              tp.schedule_task([state, &kernel = *kernel, numWG0, g1, g2,
-                                numParallelThreads](size_t threadId) mutable {
-                for (unsigned g0 = 0; g0 < numWG0; g0++) {
-                  state.update(g0, g1, g2);
+            for (size_t local2 = 0; local2 < localSize[2]; ++local2) {
+              for (size_t local1 = 0; local1 < localSize[1]; ++local1) {
+                for (size_t local0 = 0; local0 < localSize[0]; ++local0) {
+                  state.update(g0, g1, g2, local0, local1, local2);
                   kernel._subhandler(
                       kernel.getArgs(numParallelThreads, threadId).data(),
                       &state);
                 }
-              }));
-        }
-      }
-    } else {
-      // Split dimension 0 across the threadpool
-      // Here we try to create groups of workgroups in order to reduce
-      // synchronization overhead
-      for (unsigned g2 = 0; g2 < numWG2; g2++) {
-        for (unsigned g1 = 0; g1 < numWG1; g1++) {
-          for (unsigned g0 = 0; g0 < numWG0; g0++) {
-            groups.push_back([state, g0, g1, g2, numParallelThreads](
-                                 size_t threadId,
-                                 ur_kernel_handle_t_ &kernel) mutable {
-              state.update(g0, g1, g2);
-              kernel._subhandler(
-                  kernel.getArgs(numParallelThreads, threadId).data(), &state);
-            });
-          }
-        }
-      }
-      auto numGroups = groups.size();
-      auto groupsPerThread = numGroups / numParallelThreads;
-      if (groupsPerThread) {
-        for (unsigned thread = 0; thread < numParallelThreads; thread++) {
-          futures.emplace_back(
-              tp.schedule_task([groups, thread, groupsPerThread,
-                                &kernel = *kernel](size_t threadId) {
-                for (unsigned i = 0; i < groupsPerThread; i++) {
-                  auto index = thread * groupsPerThread + i;
-                  groups[index](threadId, kernel);
-                }
-              }));
-        }
-      }
-
-      // schedule the remaining tasks
-      auto remainder = numGroups % numParallelThreads;
-      if (remainder) {
-        futures.emplace_back(
-            tp.schedule_task([groups, remainder,
-                              scheduled = numParallelThreads * groupsPerThread,
-                              &kernel = *kernel](size_t threadId) {
-              for (unsigned i = 0; i < remainder; i++) {
-                auto index = scheduled + i;
-                groups[index](threadId, kernel);
               }
-            }));
-      }
-    }
+            }
+#endif
+            if (++g0 == numWG0) {
+              g0 = 0;
+              if (++g1 == numWG1) {
+                g1 = 0;
+                ++g2;
+              }
+            }
+          }
+        }));
+    rangeStart = rangeEnd;
   }
-
-#endif // NATIVECPU_USE_OCK
   event->set_futures(futures);
 
   if (phEvent) {
