 //
 // ===----------------------------------------------------------------------===//
 #include <array>
+#include <cstddef>
 #include <cstdint>
+#include <vector>
 
 #include "ur_api.h"
 
 #include "common.hpp"
 #include "kernel.hpp"
 #include "memory.hpp"
-#include "threadpool.hpp"
 #include "queue.hpp"
+#include "threadpool.hpp"
 
 namespace native_cpu {
 struct NDRDescT {
@@ -37,9 +39,29 @@ struct NDRDescT {
       GlobalOffset[I] = 0;
     }
   }
+
+  void dump(std::ostream &os) const {
+    os << "GlobalSize: " << GlobalSize[0] << " " << GlobalSize[1] << " "
+       << GlobalSize[2] << "\n";
+    os << "LocalSize: " << LocalSize[0] << " " << LocalSize[1] << " "
+       << LocalSize[2] << "\n";
+    os << "GlobalOffset: " << GlobalOffset[0] << " " << GlobalOffset[1] << " "
+       << GlobalOffset[2] << "\n";
+  }
 };
 } // namespace native_cpu
 
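+// Helper for the NATIVECPU_USE_OCK path below: builds a native_cpu::state whose
+// dimension-0 local size is itemsPerThread, so one scheduled task can cover a
+// chunk of the dimension-0 range instead of a single work-item.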
+#ifdef NATIVECPU_USE_OCK
+static native_cpu::state getResizedState(const native_cpu::NDRDescT &ndr,
+                                         size_t itemsPerThread) {
+  native_cpu::state resized_state(
+      ndr.GlobalSize[0], ndr.GlobalSize[1], ndr.GlobalSize[2], itemsPerThread,
+      ndr.LocalSize[1], ndr.LocalSize[2], ndr.GlobalOffset[0],
+      ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
+  return resized_state;
+}
+#endif
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -61,38 +83,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 
   // TODO: add proper error checking
   // TODO: add proper event dep management
-  native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize);
-  auto &tp = hQueue->device->tp;
+  native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
+                           pLocalWorkSize);
+  auto &tp = hQueue->device->tp;
   const size_t numParallelThreads = tp.num_threads();
   hKernel->updateMemPool(numParallelThreads);
   std::vector<std::future<void>> futures;
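+  // Per-work-group closures; only used on the NATIVECPU_USE_OCK nd_range path
+  // below, where each closure is invoked as (threadId, kernel).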
+  std::vector<std::function<void(size_t, ur_kernel_handle_t_)>> groups;
   auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
   auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
   auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
-  bool isLocalSizeOne =
-      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
-
-
   native_cpu::state state(ndr.GlobalSize[0], ndr.GlobalSize[1],
                           ndr.GlobalSize[2], ndr.LocalSize[0], ndr.LocalSize[1],
                           ndr.LocalSize[2], ndr.GlobalOffset[0],
                           ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
-  if (isLocalSizeOne) {
-    // If the local size is one, we make the assumption that we are running a
-    // parallel_for over a sycl::range Todo: we could add compiler checks and
-    // kernel properties for this (e.g. check that no barriers are called, no
-    // local memory args).
-
-    auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
-    auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
-    auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
+#ifndef NATIVECPU_USE_OCK
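+  // Without the OCK the kernel runs inline on the calling thread: local
+  // arguments are laid out for a single "thread" (index 0) and every
+  // work-group and work-item is executed in order by the loops below.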
+  hKernel->handleLocalArgs(1, 0);
   for (unsigned g2 = 0; g2 < numWG2; g2++) {
     for (unsigned g1 = 0; g1 < numWG1; g1++) {
       for (unsigned g0 = 0; g0 < numWG0; g0++) {
-#ifdef NATIVECPU_USE_OCK
-        state.update(g0, g1, g2);
-        hKernel->_subhandler(hKernel->_args.data(), &state);
-#else
         for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
           for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
             for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
@@ -101,13 +110,118 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
             }
           }
         }
-#endif
+      }
+    }
+  }
+#else
+  bool isLocalSizeOne =
+      ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
+  if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads) {
+    // If the local size is one, we make the assumption that we are running a
+    // parallel_for over a sycl::range.
+    // Todo: we could add compiler checks and
+    // kernel properties for this (e.g. check that no barriers are called, no
+    // local memory args).
+
+    // Todo: this assumes that dim 0 is the best dimension over which we want to
+    // parallelize
+
+    // Since we also vectorize the kernel, and vectorization happens within the
+    // work group loop, it's better to have a large-ish local size. We can
+    // divide the global range by the number of threads, set that as the local
+    // size and peel everything else.
+
+    size_t new_num_work_groups_0 = numParallelThreads;
+    size_t itemsPerThread = ndr.GlobalSize[0] / numParallelThreads;
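+    // Hypothetical example: GlobalSize[0] == 1002 with 8 threads gives
+    // itemsPerThread == 125; the first loop below schedules 8 resized
+    // work-groups of 125 items each, and the peeling loop covers the
+    // remaining 2 work-groups (g0 == 1000 and 1001).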
+
+    for (unsigned g2 = 0; g2 < numWG2; g2++) {
+      for (unsigned g1 = 0; g1 < numWG1; g1++) {
+        for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
+          futures.emplace_back(
+              tp.schedule_task([&ndr = std::as_const(ndr), itemsPerThread,
+                                hKernel, g0, g1, g2](size_t) {
+                native_cpu::state resized_state =
+                    getResizedState(ndr, itemsPerThread);
+                resized_state.update(g0, g1, g2);
+                hKernel->_subhandler(hKernel->_args.data(), &resized_state);
+              }));
+        }
+        // Peel the remaining work items. Since the local size is 1, we iterate
+        // over the work groups.
+        for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
+             g0++) {
+          state.update(g0, g1, g2);
+          hKernel->_subhandler(hKernel->_args.data(), &state);
+        }
+      }
+    }
+
+  } else {
+    // We are running a parallel_for over an nd_range
+
+    if (numWG1 * numWG2 >= numParallelThreads) {
+      // Dimensions 1 and 2 have enough work, split them across the threadpool
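+      // Note: each task captures `state` and a copy of the kernel by value
+      // (the lambda is mutable), so per-thread local-argument handling and
+      // state updates act on task-local copies.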
+      for (unsigned g2 = 0; g2 < numWG2; g2++) {
+        for (unsigned g1 = 0; g1 < numWG1; g1++) {
+          futures.emplace_back(
+              tp.schedule_task([state, kernel = *hKernel, numWG0, g1, g2,
+                                numParallelThreads](size_t threadId) mutable {
+                for (unsigned g0 = 0; g0 < numWG0; g0++) {
+                  kernel.handleLocalArgs(numParallelThreads, threadId);
+                  state.update(g0, g1, g2);
+                  kernel._subhandler(kernel._args.data(), &state);
+                }
+              }));
+        }
+      }
+    } else {
+      // Split dimension 0 across the threadpool
+      // Here we try to create groups of workgroups in order to reduce
+      // synchronization overhead
+      for (unsigned g2 = 0; g2 < numWG2; g2++) {
+        for (unsigned g1 = 0; g1 < numWG1; g1++) {
+          for (unsigned g0 = 0; g0 < numWG0; g0++) {
+            groups.push_back(
+                [state, g0, g1, g2, numParallelThreads](
+                    size_t threadId, ur_kernel_handle_t_ kernel) mutable {
+                  kernel.handleLocalArgs(numParallelThreads, threadId);
+                  state.update(g0, g1, g2);
+                  kernel._subhandler(kernel._args.data(), &state);
+                });
+          }
+        }
+      }
+      auto numGroups = groups.size();
+      auto groupsPerThread = numGroups / numParallelThreads;
+      auto remainder = numGroups % numParallelThreads;
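+      // Hypothetical example: 10 work-groups on 4 threads gives
+      // groupsPerThread == 2 and remainder == 2; the loop below schedules
+      // 4 tasks of 2 groups each, and the extra task afterwards runs the
+      // remaining 2 groups.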
+      for (unsigned thread = 0; thread < numParallelThreads; thread++) {
+        futures.emplace_back(tp.schedule_task(
+            [&groups, thread, groupsPerThread, hKernel](size_t threadId) {
+              for (unsigned i = 0; i < groupsPerThread; i++) {
+                auto index = thread * groupsPerThread + i;
+                groups[index](threadId, *hKernel);
+              }
+            }));
+      }
+
+      // schedule the remaining tasks
+      if (remainder) {
+        futures.emplace_back(
+            tp.schedule_task([&groups, remainder,
+                              scheduled = numParallelThreads * groupsPerThread,
+                              hKernel](size_t threadId) {
+              for (unsigned i = 0; i < remainder; i++) {
+                auto index = scheduled + i;
+                groups[index](threadId, *hKernel);
+              }
+            }));
+      }
       }
     }
   }
 
   for (auto &f : futures)
     f.get();
+#endif // NATIVECPU_USE_OCK
   // TODO: we should avoid calling clear here by avoiding using push_back
   // in setKernelArgs.
   hKernel->_args.clear();
@@ -553,4 +667,3 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
 
   DIE_NO_IMPLEMENTATION;
 }
-