@@ -68,7 +68,6 @@ struct KthElementF
6868 State<T> &state,
6969 uint64_t items_to_sort,
7070 uint64_t limit,
71- bool ret,
7271 const std::vector<sycl::event> &deps)
7372 {
7473 auto e = queue.submit ([&](sycl::handler &cgh) {
@@ -149,15 +148,11 @@ struct KthElementF
149148 return ;
150149 }
151150
152- // if (group.leader()) {
153- // str << "num_elems: " << num_elems << "\n";
154- // }
155-
156151 if (num_elems <= limit) {
157152 auto gh = sycl_exp::group_with_scratchpad (
158153 group, sycl::span{&scratch[0 ], temp_memory_size});
159154 if (num_elems > 0 )
160- sycl_exp::joint_sort (gh, &_in[0 ], &_in[num_elems]);
155+ sycl_exp::joint_sort (gh, &_in[0 ], &_in[num_elems], Less<T>{} );
161156
162157 if (group.leader ()) {
163158 uint64_t offset = state.counters .less_count [0 ];
@@ -168,16 +163,8 @@ struct KthElementF
168163
169164 int64_t idx = target - offset;
170165
171- // if (idx + 1 > (in + state.n - _in) || idx < 0)
172- // {
173- // str << "buffer access out of bounds idx = "
174- // << idx << " size " << (in + state.n - _in) << "\n";
175- // }
176- // else
177- {
178- state.values [0 ] = _in[idx];
179- state.values [1 ] = _in[idx + 1 ];
180- }
166+ state.values [0 ] = _in[idx];
167+ state.values [1 ] = _in[idx + 1 ];
181168
182169 state.stop [0 ] = true ;
183170 state.target_found [0 ] = true ;
@@ -186,9 +173,6 @@ struct KthElementF
186173 return ;
187174 }
188175
189- // if (ret)
190- // return;
191-
192176 uint64_t step = num_elems / items_to_sort;
193177 for (uint32_t i = llid; i < items_to_sort; i += local_size)
194178 {
@@ -204,34 +188,34 @@ struct KthElementF
204188 auto gh = sycl_exp::group_with_scratchpad (
205189 group, sycl::span{&scratch[0 ], temp_memory_size});
206190 sycl_exp::joint_sort (gh, &loc_items[0 ],
207- &loc_items[0 ] + items_to_sort);
191+ &loc_items[0 ] + items_to_sort, Less<T>{} );
208192
209193 T new_pivot = loc_items[items_to_sort / 2 ];
210194
211- // if (new_pivot != state.pivot[0]) {
195+ if (new_pivot != state.pivot [0 ]) {
212196 if (group.leader ()) {
213197 state.pivot [0 ] = new_pivot;
214198 state.num_elems [0 ] = num_elems;
215199 }
216200 return ;
217- // }
218-
219- // auto start = llid + items_to_sort / 2 + 1;
220- // uint32_t index = start;
221- // for (uint32_t i = start; i < items_to_sort; i += local_size)
222- // {
223- // if (loc_items[i] != new_pivot) {
224- // index = i;
225- // break;
226- // }
227- // }
228-
229- // index = sycl::reduce_over_group(group, index,
230- // sycl::minimum<>());
231- // if (group.leader()) {
232- // state.pivot[0] = loc_items[index];
233- // state.num_elems[0] = num_elems;
234- // }
201+ }
202+
203+ auto start = llid + items_to_sort / 2 + 1 ;
204+ uint32_t index = start;
205+ for (uint32_t i = start; i < items_to_sort; i += local_size)
206+ {
207+ if (loc_items[i] != new_pivot) {
208+ index = i;
209+ break ;
210+ }
211+ }
212+
213+ index = sycl::reduce_over_group (group, index,
214+ sycl::minimum<>());
215+ if (group.leader ()) {
216+ state.pivot [0 ] = loc_items[index];
217+ state.num_elems [0 ] = num_elems;
218+ }
235219 });
236220 });
237221
@@ -262,6 +246,7 @@ struct KthElementF
262246 static sycl::event run_kth_element (sycl::queue &exec_q,
263247 const T *in,
264248 T *partitioned,
249+ T *temp_buff,
265250 const size_t k,
266251 State<T> &state,
267252 PartitionState<T> &pstate,
@@ -274,31 +259,21 @@ struct KthElementF
274259 // Ensure iterations are odd so the final result is always stored in 'partitioned'
275260 iterations += 1 - iterations % 2 ;
276261
277- auto temp_buff = dpctl_utils::smart_malloc<T>(state.n , exec_q,
278- sycl::usm::alloc::device);
279-
280- std::cout << " Iteration " << 0 << std::endl;
281262 auto prev = run_pick_pivot (exec_q, const_cast <T *>(in), partitioned, k,
282- state, items_to_sort, limit, false , depends);
263+ state, items_to_sort, limit, depends);
283264 prev = run_partition (exec_q, const_cast <T *>(in), partitioned, pstate,
284265 {prev});
285- // prev.wait();
286266
287267 T *_in = partitioned;
288- T *_out = temp_buff. get () ;
268+ T *_out = temp_buff;
289269 for (uint32_t i = 0 ; i < iterations - 1 ; ++i) {
290- std::cout << " Iteration " << i + 1 << std::endl;
291270 prev = run_pick_pivot (exec_q, _in, _out, k, state,
292- items_to_sort, limit, true , {prev});
271+ items_to_sort, limit, {prev});
293272 prev = run_partition (exec_q, _in, _out, pstate, {prev});
294273 std::swap (_in, _out);
295- // prev.wait();
296- // if (i % 5 == 0)
297- // prev.wait();
298274 }
299- prev.wait ();
300275 prev = run_pick_pivot (exec_q, _in, _out, k, state, items_to_sort, limit,
301- true , {prev});
276+ {prev});
302277
303278 return prev;
304279 }
@@ -322,7 +297,9 @@ struct KthElementF
322297 auto init_e = state.init (exec_queue, depends);
323298 init_e = pstate.init (exec_queue, {init_e});
324299
325- auto evt = run_kth_element (exec_queue, ain, partitioned, k, state,
300+ auto temp_buff = dpctl_utils::smart_malloc<T>(state.n , exec_queue,
301+ sycl::usm::alloc::device);
302+ auto evt = run_kth_element (exec_queue, ain, partitioned, temp_buff.get (), k, state,
326303 pstate, {init_e});
327304
328305 bool found = false ;
@@ -343,52 +320,37 @@ struct KthElementF
343320 uint64_t buff_offset = 0 ;
344321 uint64_t elems_offset = less_count;
345322
346- try
347- {
348- copy_evt.wait ();
349-
350- if (!found) {
351- if (left) {
352- elems_offset = less_count - num_elems;
353- }
354- else {
355- buff_offset = a_size - num_elems;
356- }
323+ copy_evt.wait ();
324+
325+ if (!found) {
326+ if (left) {
327+ elems_offset = less_count - num_elems;
357328 }
358329 else {
359- num_elems = 2 ;
360- elems_offset = k;
330+ buff_offset = a_size - num_elems;
361331 }
362-
363- auto end = std::chrono::high_resolution_clock::now ();
364-
365- auto duration =
366- std::chrono::duration_cast<std::chrono::microseconds>(end - start)
367- .count ();
368-
369- std::cout << " KthElement1d took " << duration << " microseconds"
370- << std::endl;
371-
372- std::cout << " Found " << found << " left " << left
373- << " less_count " << less_count
374- << " greater_equal_count " << greater_equal_count
375- << " num_elems " << num_elems
376- << " nan_count " << nan_count
377- << std::endl;
378- /* code */
379332 }
380- catch (sycl::exception const &e)
381- {
382- std::cout << e. what () << std::endl ;
333+ else {
334+ num_elems = 2 ;
335+ elems_offset = k ;
383336 }
384337
385338 state.cleanup (exec_queue);
339+
340+ auto end = std::chrono::high_resolution_clock::now ();
341+
342+ auto duration =
343+ std::chrono::duration_cast<std::chrono::microseconds>(end - start)
344+ .count ();
345+
346+ std::cout << " KthElement1d took " << duration << " microseconds"
347+ << std::endl;
386348 return {found, buff_offset, elems_offset, num_elems, nan_count};
387349 }
388350};
389351
390352using SupportedTypes =
391- std::tuple<uint32_t , int32_t , uint64_t , int64_t , float , double >;
353+ std::tuple<uint32_t , int32_t , uint64_t , int64_t , float , double , std:: complex < float >, std:: complex < double > >;
392354} // namespace
393355
394356KthElement1d::KthElement1d () : dispatch_table(" a" )
0 commit comments