@@ -52,17 +52,6 @@ struct NDRDescT {
 };
 } // namespace native_cpu
 
-#ifdef NATIVECPU_USE_OCK
-static native_cpu::state getResizedState(const native_cpu::NDRDescT &ndr,
-                                         size_t itemsPerThread) {
-  native_cpu::state resized_state(
-      ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], itemsPerThread,
-      ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0],
-      ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
-  return resized_state;
-}
-#endif
-
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -112,6 +101,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   // TODO: add proper error checking
   native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
                            pLocalWorkSize);
+  unsigned long long numWI;
+  auto umulll_overflow = [](unsigned long long a, unsigned long long b,
+                            unsigned long long *c) -> bool {
+#ifdef __GNUC__
+    return __builtin_umulll_overflow(a, b, c);
+#else
+    *c = a * b;
+    return a != 0 && b != *c / a;
+#endif
+  };
+  if (umulll_overflow(ndr.GlobalSize[0], ndr.GlobalSize[1], &numWI) ||
+      umulll_overflow(numWI, ndr.GlobalSize[2], &numWI) || numWI > SIZE_MAX) {
+    return UR_RESULT_ERROR_OUT_OF_RESOURCES;
+  }
+
   auto &tp = hQueue->getDevice()->tp;
   const size_t numParallelThreads = tp.num_threads();
   std::vector<std::future<void>> futures;
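The guard added in this hunk rejects launches whose total work-item count cannot be represented. Below is a minimal standalone sketch of the same check, assuming only standard C++ plus the GCC/Clang __builtin_umulll_overflow builtin with a portable fallback; the name checked_mul and the example sizes are illustrative, not part of the adapter.

#include <cstdint>
#include <cstdio>

// Overflow-checked multiply: returns true if a * b overflowed.
// Mirrors the lambda in the hunk above; checked_mul is a hypothetical name.
static bool checked_mul(unsigned long long a, unsigned long long b,
                        unsigned long long *out) {
#ifdef __GNUC__
  return __builtin_umulll_overflow(a, b, out);
#else
  *out = a * b;                   // unsigned wraparound is well-defined
  return a != 0 && b != *out / a; // true iff the product wrapped
#endif
}

int main() {
  // 2^21 * 2^21 * 2^22 = 2^64 work-items, which exceeds ULLONG_MAX.
  unsigned long long dims[3] = {1ull << 21, 1ull << 21, 1ull << 22};
  unsigned long long numWI = 0;
  bool tooLarge = checked_mul(dims[0], dims[1], &numWI) ||
                  checked_mul(numWI, dims[2], &numWI) || numWI > SIZE_MAX;
  std::printf("launch rejected: %s\n", tooLarge ? "yes" : "no"); // "yes"
  return 0;
}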
@@ -130,131 +134,56 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
   kernel->updateMemPool(numParallelThreads);
 
+  const size_t numWG = numWG0 * numWG1 * numWG2;
+  const size_t numWGPerThread = numWG / numParallelThreads;
+  const size_t remainderWG = numWG - numWGPerThread * numParallelThreads;
+  // The fourth value is the linearized value.
+  std::array<size_t, 4> rangeStart = {0, 0, 0, 0};
+  for (unsigned t = 0; t < numParallelThreads; ++t) {
+    auto rangeEnd = rangeStart;
+    rangeEnd[3] += numWGPerThread + (t < remainderWG);
+    if (rangeEnd[3] == rangeStart[3])
+      break;
+    rangeEnd[0] = rangeEnd[3] % numWG0;
+    rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
+    rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);
+    futures.emplace_back(
+        tp.schedule_task([state, &kernel = *kernel, rangeStart,
+                          rangeEnd = rangeEnd[3], numWG0, numWG1,
 #ifndef NATIVECPU_USE_OCK
-  for (unsigned g2 = 0; g2 < numWG2; g2++) {
-    for (unsigned g1 = 0; g1 < numWG1; g1++) {
-      for (unsigned g0 = 0; g0 < numWG0; g0++) {
-        for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
-          for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
-            for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
-              state.update(g0, g1, g2, local0, local1, local2);
-              kernel->_subhandler(kernel->getArgs(1, 0).data(), &state);
-            }
-          }
-        }
-      }
-    }
-  }
+                          localSize = ndr.LocalSize,
+#endif
+                          numParallelThreads](size_t threadId) mutable {
+          for (size_t g0 = rangeStart[0], g1 = rangeStart[1],
+                      g2 = rangeStart[2], g3 = rangeStart[3];
+               g3 < rangeEnd; ++g3) {
+#ifdef NATIVECPU_USE_OCK
+            state.update(g0, g1, g2);
+            kernel._subhandler(
+                kernel.getArgs(numParallelThreads, threadId).data(), &state);
 #else
-  bool isLocalSizeOne =
-      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
-      !kernel->hasLocalArgs()) {
-    // If the local size is one, we make the assumption that we are running a
-    // parallel_for over a sycl::range.
-    // Todo: we could add more compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called).
-
-    // Todo: this assumes that dim 0 is the best dimension over which we want to
-    // parallelize
-
-    // Since we also vectorize the kernel, and vectorization happens within the
-    // work group loop, it's better to have a large-ish local size. We can
-    // divide the global range by the number of threads, set that as the local
-    // size and peel everything else.
-
-    size_t new_num_work_groups_0 = numParallelThreads;
-    size_t itemsPerThread = ndr.GlobalSize[0] / numParallelThreads;
-
-    for (unsigned g2 = 0; g2 < numWG2; g2++) {
-      for (unsigned g1 = 0; g1 < numWG1; g1++) {
-        for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
-          futures.emplace_back(tp.schedule_task(
-              [ndr, itemsPerThread, &kernel = *kernel, g0, g1, g2](size_t) {
-                native_cpu::state resized_state =
-                    getResizedState(ndr, itemsPerThread);
-                resized_state.update(g0, g1, g2);
-                kernel._subhandler(kernel.getArgs().data(), &resized_state);
-              }));
-        }
-        // Peel the remaining work items. Since the local size is 1, we iterate
-        // over the work groups.
-        for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
-             g0++) {
-          state.update(g0, g1, g2);
-          kernel->_subhandler(kernel->getArgs().data(), &state);
-        }
-      }
-    }
-
-  } else {
-    // We are running a parallel_for over an nd_range
-
-    if (numWG1 * numWG2 >= numParallelThreads) {
-      // Dimensions 1 and 2 have enough work, split them across the threadpool
-      for (unsigned g2 = 0; g2 < numWG2; g2++) {
-        for (unsigned g1 = 0; g1 < numWG1; g1++) {
-          futures.emplace_back(
-              tp.schedule_task([state, &kernel = *kernel, numWG0, g1, g2,
-                                numParallelThreads](size_t threadId) mutable {
-                for (unsigned g0 = 0; g0 < numWG0; g0++) {
-                  state.update(g0, g1, g2);
+            for (size_t local2 = 0; local2 < localSize[2]; ++local2) {
+              for (size_t local1 = 0; local1 < localSize[1]; ++local1) {
+                for (size_t local0 = 0; local0 < localSize[0]; ++local0) {
+                  state.update(g0, g1, g2, local0, local1, local2);
                   kernel._subhandler(
                       kernel.getArgs(numParallelThreads, threadId).data(),
                       &state);
                 }
-              }));
-        }
-      }
-    } else {
-      // Split dimension 0 across the threadpool
-      // Here we try to create groups of workgroups in order to reduce
-      // synchronization overhead
-      for (unsigned g2 = 0; g2 < numWG2; g2++) {
-        for (unsigned g1 = 0; g1 < numWG1; g1++) {
-          for (unsigned g0 = 0; g0 < numWG0; g0++) {
-            groups.push_back([state, g0, g1, g2, numParallelThreads](
-                                 size_t threadId,
-                                 ur_kernel_handle_t_ &kernel) mutable {
-              state.update(g0, g1, g2);
-              kernel._subhandler(
-                  kernel.getArgs(numParallelThreads, threadId).data(), &state);
-            });
-          }
-        }
-      }
-      auto numGroups = groups.size();
-      auto groupsPerThread = numGroups / numParallelThreads;
-      if (groupsPerThread) {
-        for (unsigned thread = 0; thread < numParallelThreads; thread++) {
-          futures.emplace_back(
-              tp.schedule_task([groups, thread, groupsPerThread,
-                                &kernel = *kernel](size_t threadId) {
-                for (unsigned i = 0; i < groupsPerThread; i++) {
-                  auto index = thread * groupsPerThread + i;
-                  groups[index](threadId, kernel);
-                }
-              }));
-        }
-      }
-
-      // schedule the remaining tasks
-      auto remainder = numGroups % numParallelThreads;
-      if (remainder) {
-        futures.emplace_back(
-            tp.schedule_task([groups, remainder,
-                              scheduled = numParallelThreads * groupsPerThread,
-                              &kernel = *kernel](size_t threadId) {
-              for (unsigned i = 0; i < remainder; i++) {
-                auto index = scheduled + i;
-                groups[index](threadId, kernel);
               }
-            }));
-      }
-    }
+            }
+#endif
+            if (++g0 == numWG0) {
+              g0 = 0;
+              if (++g1 == numWG1) {
+                g1 = 0;
+                ++g2;
+              }
+            }
+          }
+        }));
+    rangeStart = rangeEnd;
   }
-
-#endif // NATIVECPU_USE_OCK
   event->set_futures(futures);
 
   if (phEvent) {
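The replacement scheduling path above linearizes the work-group grid, hands each thread a contiguous slice of roughly numWG / numParallelThreads groups (the first remainderWG threads take one extra), and walks each slice by carrying a 3D index alongside the linear one. Below is a minimal standalone sketch of that arithmetic, using illustrative sizes and plain loops rather than the adapter's thread pool and state types.

#include <array>
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes: a 5 x 3 x 2 grid of work-groups, 4 worker threads.
  const size_t numWG0 = 5, numWG1 = 3, numWG2 = 2;
  const size_t numThreads = 4;

  const size_t numWG = numWG0 * numWG1 * numWG2;
  const size_t perThread = numWG / numThreads;
  const size_t remainder = numWG - perThread * numThreads;

  // rangeStart[0..2] is a 3D work-group id, rangeStart[3] its linearized value.
  std::array<size_t, 4> rangeStart = {0, 0, 0, 0};
  for (size_t t = 0; t < numThreads; ++t) {
    auto rangeEnd = rangeStart;
    // The first `remainder` threads each take one extra work-group.
    rangeEnd[3] += perThread + (t < remainder);
    if (rangeEnd[3] == rangeStart[3])
      break; // more threads than work-groups: nothing left to hand out
    rangeEnd[0] = rangeEnd[3] % numWG0;
    rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
    rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);

    // Walk the slice the way each scheduled task does: advance the linear
    // index and carry overflow through the 3D coordinates.
    size_t g0 = rangeStart[0], g1 = rangeStart[1], g2 = rangeStart[2];
    for (size_t g3 = rangeStart[3]; g3 < rangeEnd[3]; ++g3) {
      // ... run work-group (g0, g1, g2) here ...
      if (++g0 == numWG0) {
        g0 = 0;
        if (++g1 == numWG1) {
          g1 = 0;
          ++g2;
        }
      }
    }
    // The carried index lands exactly on the decomposition of rangeEnd[3].
    std::printf("thread %zu: groups [%zu, %zu), walk ends at (%zu, %zu, %zu)\n",
                t, rangeStart[3], rangeEnd[3], g0, g1, g2);
    rangeStart = rangeEnd;
  }
  return 0;
}

With the sizes above the 30 groups are split as 8, 8, 7, 7, and each thread's carried (g0, g1, g2) ends on the first group of the next slice, matching the modulo decomposition used for rangeEnd.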