4040#include " kth_element1d.hpp"
4141#include " partitioning.hpp"
4242
43- // #include <iostream>
43+ #include < iostream>
44+ #include < chrono>
4445
4546namespace sycl_exp = sycl::ext::oneapi::experimental;
4647namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -67,6 +68,7 @@ struct KthElementF
6768 State<T> &state,
6869 uint64_t items_to_sort,
6970 uint64_t limit,
71+ bool ret,
7072 const std::vector<sycl::event> &deps)
7173 {
7274 auto e = queue.submit ([&](sycl::handler &cgh) {
@@ -84,6 +86,12 @@ struct KthElementF
8486 auto scratch = sycl::local_accessor<std::byte, 1 >(
8587 sycl::range<1 >(temp_memory_size), cgh);
8688
89+ // std::cout << "temp_memory_size: " << temp_memory_size
90+ // << " items_to_sort: " << items_to_sort
91+ // << " limit: " << limit
92+ // << " group_size: " << group_size << "\n";
93+
94+ // auto str = sycl::stream(8192, 1024, cgh);
8795 cgh.parallel_for <pick_pivot_kernel<T>>(
8896 work_sz, [=](sycl::nd_item<1 > item) {
8997 auto group = item.get_group ();
@@ -129,7 +137,6 @@ struct KthElementF
129137 target_found = true ;
130138 }
131139 }
132-
133140 state.reset_iteration_counters ();
134141 }
135142
@@ -142,10 +149,15 @@ struct KthElementF
142149 return ;
143150 }
144151
152+ // if (group.leader()) {
153+ // str << "num_elems: " << num_elems << "\n";
154+ // }
155+
145156 if (num_elems <= limit) {
146157 auto gh = sycl_exp::group_with_scratchpad (
147158 group, sycl::span{&scratch[0 ], temp_memory_size});
148- sycl_exp::joint_sort (gh, &_in[0 ], &_in[num_elems]);
159+ if (num_elems > 0 )
160+ sycl_exp::joint_sort (gh, &_in[0 ], &_in[num_elems]);
149161
150162 if (group.leader ()) {
151163 uint64_t offset = state.counters .less_count [0 ];
@@ -154,9 +166,18 @@ struct KthElementF
154166 state.counters .less_count [0 ] - num_elems;
155167 }
156168
157- uint64_t idx = target - offset;
158- state.values [0 ] = _in[idx];
159- state.values [1 ] = _in[idx + 1 ];
169+ int64_t idx = target - offset;
170+
171+ // if (idx + 1 > (in + state.n - _in) || idx < 0)
172+ // {
173+ // str << "buffer access out of bounds idx = "
174+ // << idx << " size " << (in + state.n - _in) << "\n";
175+ // }
176+ // else
177+ {
178+ state.values [0 ] = _in[idx];
179+ state.values [1 ] = _in[idx + 1 ];
180+ }
160181
161182 state.stop [0 ] = true ;
162183 state.target_found [0 ] = true ;
@@ -165,6 +186,9 @@ struct KthElementF
165186 return ;
166187 }
167188
189+ // if (ret)
190+ // return;
191+
168192 uint64_t step = num_elems / items_to_sort;
169193 for (uint32_t i = llid; i < items_to_sort; i += local_size)
170194 {
@@ -184,30 +208,30 @@ struct KthElementF
184208
185209 T new_pivot = loc_items[items_to_sort / 2 ];
186210
187- if (new_pivot != state.pivot [0 ]) {
211+ // if (new_pivot != state.pivot[0]) {
188212 if (group.leader ()) {
189213 state.pivot [0 ] = new_pivot;
190214 state.num_elems [0 ] = num_elems;
191215 }
192216 return ;
193- }
194-
195- auto start = llid + items_to_sort / 2 + 1 ;
196- uint32_t index = start;
197- for (uint32_t i = start; i < items_to_sort; i += local_size)
198- {
199- if (loc_items[i] != new_pivot) {
200- index = i;
201- break ;
202- }
203- }
204-
205- index = sycl::reduce_over_group (group, index,
206- sycl::minimum<>());
207- if (group.leader ()) {
208- state.pivot [0 ] = loc_items[index];
209- state.num_elems [0 ] = num_elems;
210- }
217+ // }
218+
219+ // auto start = llid + items_to_sort / 2 + 1;
220+ // uint32_t index = start;
221+ // for (uint32_t i = start; i < items_to_sort; i += local_size)
222+ // {
223+ // if (loc_items[i] != new_pivot) {
224+ // index = i;
225+ // break;
226+ // }
227+ // }
228+
229+ // index = sycl::reduce_over_group(group, index,
230+ // sycl::minimum<>());
231+ // if (group.leader()) {
232+ // state.pivot[0] = loc_items[index];
233+ // state.num_elems[0] = num_elems;
234+ // }
211235 });
212236 });
213237
@@ -225,7 +249,7 @@ struct KthElementF
225249 auto e = exec_q.submit ([&](sycl::handler &cgh) {
226250 cgh.depends_on (deps);
227251
228- constexpr uint32_t WorkPI = 4 ; // empirically found number
252+ constexpr uint32_t WorkPI = 1 ; // empirically found number
229253
230254 auto work_range = make_ndrange (state.n , group_size, WorkPI);
231255 submit_partition_one_pivot<T, WorkPI>(cgh, work_range, in, out,
@@ -243,47 +267,57 @@ struct KthElementF
243267 PartitionState<T> &pstate,
244268 const std::vector<sycl::event> &depends)
245269 {
246- uint32_t items_to_sort = 128 ;
247- uint32_t limit = 4 * items_to_sort;
270+ uint32_t items_to_sort = 127 ;
271+ uint32_t limit = 4 * ( items_to_sort + 1 ) ;
248272 uint32_t iterations =
249- std::ceil (std::log (double (state.n ) / limit) / std::log (2 ));
273+ std::ceil (-std::log (double (state.n ) / limit) / std::log (0.536 )) + 1 ;
274+ // Ensure iterations are odd so the final result is always stored in 'partitioned'
275+ iterations += 1 - iterations % 2 ;
250276
251277 auto temp_buff = dpctl_utils::smart_malloc<T>(state.n , exec_q,
252278 sycl::usm::alloc::device);
253279
280+ std::cout << " Iteration " << 0 << std::endl;
254281 auto prev = run_pick_pivot (exec_q, const_cast <T *>(in), partitioned, k,
255- state, items_to_sort, limit, depends);
282+ state, items_to_sort, limit, false , depends);
256283 prev = run_partition (exec_q, const_cast <T *>(in), partitioned, pstate,
257284 {prev});
285+ // prev.wait();
258286
259287 T *_in = partitioned;
260288 T *_out = temp_buff.get ();
261289 for (uint32_t i = 0 ; i < iterations - 1 ; ++i) {
262- prev = run_pick_pivot (exec_q, _in, _out, k, state, limit,
263- items_to_sort, {prev});
290+ std::cout << " Iteration " << i + 1 << std::endl;
291+ prev = run_pick_pivot (exec_q, _in, _out, k, state,
292+ items_to_sort, limit, true , {prev});
264293 prev = run_partition (exec_q, _in, _out, pstate, {prev});
265294 std::swap (_in, _out);
295+ // prev.wait();
296+ // if (i % 5 == 0)
297+ // prev.wait();
266298 }
267- prev = run_pick_pivot (exec_q, _in, _out, k, state, limit, items_to_sort ,
268- {prev});
299+ prev = run_pick_pivot (exec_q, _in, _out, k, state, items_to_sort, limit ,
300+ true , {prev});
269301
270302 return prev;
271303 }
272304
273- static std::tuple< bool , uint64_t , uint64_t , uint64_t >
305+ static KthElement1d::RetT
274306 impl (sycl::queue &exec_queue,
275307 const void *v_ain,
276308 void *v_partitioned,
277309 const size_t a_size,
278310 const size_t k,
279311 const std::vector<sycl::event> &depends)
280312 {
313+ auto start = std::chrono::high_resolution_clock::now ();
281314 const T *ain = static_cast <const T *>(v_ain);
282315 T *partitioned = static_cast <T *>(v_partitioned);
283316
284317 State<T> state (exec_queue, a_size, partitioned);
285318 PartitionState<T> pstate (state);
286319
320+ exec_queue.wait ();
287321 auto init_e = state.init (exec_queue, depends);
288322 init_e = pstate.init (exec_queue, {init_e});
289323
@@ -295,34 +329,60 @@ struct KthElementF
295329 uint64_t less_count = 0 ;
296330 uint64_t greater_equal_count = 0 ;
297331 uint64_t num_elems = 0 ;
332+ uint64_t nan_count = 0 ;
298333 auto copy_evt = exec_queue.copy (state.target_found , &found, 1 , evt);
299334 copy_evt = exec_queue.copy (state.left , &left, 1 , copy_evt);
300335 copy_evt = exec_queue.copy (state.counters .less_count , &less_count, 1 ,
301336 copy_evt);
302337 copy_evt = exec_queue.copy (state.counters .greater_equal_count ,
303338 &greater_equal_count, 1 , copy_evt);
304339 copy_evt = exec_queue.copy (state.num_elems , &num_elems, 1 , copy_evt);
305-
306- copy_evt.wait ();
340+ copy_evt = exec_queue.copy (state.counters .nan_count , &nan_count, 1 , copy_evt);
307341
308342 uint64_t buff_offset = 0 ;
309343 uint64_t elems_offset = less_count;
310- if (!found) {
311- if (left) {
312- elems_offset = less_count - num_elems;
344+
345+ try
346+ {
347+ copy_evt.wait ();
348+
349+ if (!found) {
350+ if (left) {
351+ elems_offset = less_count - num_elems;
352+ }
353+ else {
354+ buff_offset = a_size - num_elems;
355+ }
313356 }
314357 else {
315- buff_offset = a_size - num_elems;
358+ num_elems = 2 ;
359+ elems_offset = k;
316360 }
361+
362+ state.cleanup (exec_queue);
363+ auto end = std::chrono::high_resolution_clock::now ();
364+
365+ auto duration =
366+ std::chrono::duration_cast<std::chrono::microseconds>(end - start)
367+ .count ();
368+
369+ std::cout << " KthElement1d took " << duration << " microseconds"
370+ << std::endl;
371+
372+ std::cout << " Found " << found << " left " << left
373+ << " less_count " << less_count
374+ << " greater_equal_count " << greater_equal_count
375+ << " num_elems " << num_elems
376+ << " nan_count " << nan_count
377+ << std::endl;
378+ /* code */
317379 }
318- else {
319- num_elems = 2 ;
320- elems_offset = k ;
380+ catch (sycl::exception const &e)
381+ {
382+ std::cout << e. what () << std::endl ;
321383 }
322384
323- state.cleanup (exec_queue);
324-
325- return {found, buff_offset, elems_offset, num_elems};
385+ return {found, buff_offset, elems_offset, num_elems, nan_count};
326386 }
327387};
328388
@@ -335,7 +395,7 @@ KthElement1d::KthElement1d() : dispatch_table("a")
335395 dispatch_table.populate_dispatch_table <SupportedTypes, KthElementF>();
336396}
337397
338- std::tuple< bool , uint64_t , uint64_t , uint64_t >
398+ KthElement1d::RetT
339399 KthElement1d::call (const dpctl::tensor::usm_ndarray &a,
340400 dpctl::tensor::usm_ndarray &partitioned,
341401 const size_t k,
0 commit comments