@@ -273,8 +273,7 @@ inclusive_scan_base_step_blocked(sycl::queue &exec_q,
273273
274274 outputT wg_iscan_val;
275275 if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
276- outputT>::value)
277- {
276+ outputT>::value) {
278277 wg_iscan_val = sycl::inclusive_scan_over_group (
279278 it.get_group (), local_iscan.back (), scan_op, identity);
280279 }
@@ -447,8 +446,7 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
447446
448447 outputT wg_iscan_val;
449448 if constexpr (can_use_inclusive_scan_over_group<ScanOpT,
450- outputT>::value)
451- {
449+ outputT>::value) {
452450 wg_iscan_val = sycl::inclusive_scan_over_group (
453451 it.get_group (), local_iscan.back (), scan_op, identity);
454452 }
@@ -472,35 +470,32 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
472470 it.barrier (sycl::access::fence_space::local_space);
473471
474472 // convert back to blocked layout
475- {
476- {
477- const std::uint32_t local_offset0 = lid * n_wi;
473+ {{const std::uint32_t local_offset0 = lid * n_wi;
478474#pragma unroll
479- for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
480- slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
481- }
475+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
476+ slm_iscan_tmp[local_offset0 + m_wi] = local_iscan[m_wi];
477+ }
482478
483- it.barrier (sycl::access::fence_space::local_space);
479+ it.barrier (sycl::access::fence_space::local_space);
484480 }
485481 }
486482
487483 {
488- const std::uint32_t block_offset =
489- sgroup_id * sgSize * n_wi + lane_id;
484+ const std::uint32_t block_offset = sgroup_id * sgSize * n_wi + lane_id;
490485#pragma unroll
491- for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
492- const std::uint32_t m_wi_scaled = m_wi * sgSize;
493- const std::size_t out_id = inp_id0 + m_wi_scaled;
494- if (out_id < acc_nelems) {
495- output[out_iter_offset + out_indexer (out_id)] =
496- slm_iscan_tmp[block_offset + m_wi_scaled];
497- }
498- }
486+ for (nwiT m_wi = 0 ; m_wi < n_wi; ++m_wi) {
487+ const std::uint32_t m_wi_scaled = m_wi * sgSize;
488+ const std::size_t out_id = inp_id0 + m_wi_scaled;
489+ if (out_id < acc_nelems) {
490+ output[out_iter_offset + out_indexer (out_id)] =
491+ slm_iscan_tmp[block_offset + m_wi_scaled];
499492 }
500- });
501- });
493+ }
494+ }
495+ });
496+ });
502497
503- return inc_scan_phase1_ev;
498+ return inc_scan_phase1_ev;
504499}
505500
506501template <typename inputT,
@@ -530,6 +525,8 @@ inclusive_scan_base_step(sycl::queue &exec_q,
530525 std::size_t &acc_groups,
531526 const std::vector<sycl::event> &depends = {})
532527{
528+ // For small stride use striped load/store.
529+ // Threshold value chosen experimentally.
533530 if (s1 <= 16 ) {
534531 return inclusive_scan_base_step_striped<
535532 inputT, outputT, n_wi, IterIndexerT, InpIndexerT, OutIndexerT,
0 commit comments