4040#include " kth_element1d.hpp"
4141#include " partitioning.hpp"
4242
43- #include < iostream>
4443#include < chrono>
44+ #include < iostream>
4545
4646namespace sycl_exp = sycl::ext::oneapi::experimental;
4747namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -152,7 +152,8 @@ struct KthElementF
152152 auto gh = sycl_exp::group_with_scratchpad (
153153 group, sycl::span{&scratch[0 ], temp_memory_size});
154154 if (num_elems > 0 )
155- sycl_exp::joint_sort (gh, &_in[0 ], &_in[num_elems], Less<T>{});
155+ sycl_exp::joint_sort (gh, &_in[0 ], &_in[num_elems],
156+ Less<T>{});
156157
157158 if (group.leader ()) {
158159 uint64_t offset = state.counters .less_count [0 ];
@@ -188,7 +189,8 @@ struct KthElementF
188189 auto gh = sycl_exp::group_with_scratchpad (
189190 group, sycl::span{&scratch[0 ], temp_memory_size});
190191 sycl_exp::joint_sort (gh, &loc_items[0 ],
191- &loc_items[0 ] + items_to_sort, Less<T>{});
192+ &loc_items[0 ] + items_to_sort,
193+ Less<T>{});
192194
193195 T new_pivot = loc_items[items_to_sort / 2 ];
194196
@@ -256,7 +258,8 @@ struct KthElementF
256258 uint32_t limit = 4 * (items_to_sort + 1 );
257259 uint32_t iterations =
258260 std::ceil (-std::log (double (state.n ) / limit) / std::log (0.536 )) + 1 ;
259- // Ensure iterations are odd so the final result is always stored in 'partitioned'
261+ // Ensure iterations are odd so the final result is always stored in
262+ // 'partitioned'
260263 iterations += 1 - iterations % 2 ;
261264
262265 auto prev = run_pick_pivot (exec_q, const_cast <T *>(in), partitioned, k,
@@ -267,8 +270,8 @@ struct KthElementF
267270 T *_in = partitioned;
268271 T *_out = temp_buff;
269272 for (uint32_t i = 0 ; i < iterations - 1 ; ++i) {
270- prev = run_pick_pivot (exec_q, _in, _out, k, state,
271- items_to_sort, limit, {prev});
273+ prev = run_pick_pivot (exec_q, _in, _out, k, state, items_to_sort,
274+ limit, {prev});
272275 prev = run_partition (exec_q, _in, _out, pstate, {prev});
273276 std::swap (_in, _out);
274277 }
@@ -278,13 +281,12 @@ struct KthElementF
278281 return prev;
279282 }
280283
281- static KthElement1d::RetT
282- impl (sycl::queue &exec_queue,
283- const void *v_ain,
284- void *v_partitioned,
285- const size_t a_size,
286- const size_t k,
287- const std::vector<sycl::event> &depends)
284+ static KthElement1d::RetT impl (sycl::queue &exec_queue,
285+ const void *v_ain,
286+ void *v_partitioned,
287+ const size_t a_size,
288+ const size_t k,
289+ const std::vector<sycl::event> &depends)
288290 {
289291 auto start = std::chrono::high_resolution_clock::now ();
290292 const T *ain = static_cast <const T *>(v_ain);
@@ -298,9 +300,9 @@ struct KthElementF
298300 init_e = pstate.init (exec_queue, {init_e});
299301
300302 auto temp_buff = dpctl_utils::smart_malloc<T>(state.n , exec_queue,
301- sycl::usm::alloc::device);
302- auto evt = run_kth_element (exec_queue, ain, partitioned, temp_buff. get (), k, state,
303- pstate, {init_e});
303+ sycl::usm::alloc::device);
304+ auto evt = run_kth_element (exec_queue, ain, partitioned,
305+ temp_buff. get (), k, state, pstate, {init_e});
304306
305307 bool found = false ;
306308 bool left = false ;
@@ -315,7 +317,8 @@ struct KthElementF
315317 copy_evt = exec_queue.copy (state.counters .greater_equal_count ,
316318 &greater_equal_count, 1 , copy_evt);
317319 copy_evt = exec_queue.copy (state.num_elems , &num_elems, 1 , copy_evt);
318- copy_evt = exec_queue.copy (state.counters .nan_count , &nan_count, 1 , copy_evt);
320+ copy_evt =
321+ exec_queue.copy (state.counters .nan_count , &nan_count, 1 , copy_evt);
319322
320323 uint64_t buff_offset = 0 ;
321324 uint64_t elems_offset = less_count;
@@ -344,25 +347,30 @@ struct KthElementF
344347 .count ();
345348
346349 std::cout << " KthElement1d took " << duration << " microseconds"
347- << std::endl;
350+ << std::endl;
348351 return {found, buff_offset, elems_offset, num_elems, nan_count};
349352 }
350353};
351354
352- using SupportedTypes =
353- std::tuple<uint32_t , int32_t , uint64_t , int64_t , float , double , std::complex <float >, std::complex <double >>;
355+ using SupportedTypes = std::tuple<uint32_t ,
356+ int32_t ,
357+ uint64_t ,
358+ int64_t ,
359+ float ,
360+ double ,
361+ std::complex <float >,
362+ std::complex <double >>;
354363} // namespace
355364
356365KthElement1d::KthElement1d () : dispatch_table(" a" )
357366{
358367 dispatch_table.populate_dispatch_table <SupportedTypes, KthElementF>();
359368}
360369
361- KthElement1d::RetT
362- KthElement1d::call (const dpctl::tensor::usm_ndarray &a,
363- dpctl::tensor::usm_ndarray &partitioned,
364- const size_t k,
365- const std::vector<sycl::event> &depends)
370+ KthElement1d::RetT KthElement1d::call (const dpctl::tensor::usm_ndarray &a,
371+ dpctl::tensor::usm_ndarray &partitioned,
372+ const size_t k,
373+ const std::vector<sycl::event> &depends)
366374{
367375 // validate(a, partitioned, k);
368376
0 commit comments