@@ -12,20 +12,20 @@ template <typename T, size_t Rows, size_t Cols, layout Layout, use Use>
1212class matrix_process ;
1313
1414template <typename TResult, typename AccessorType>
15- void reduce_and_accumulate (sub_group sg, size_t sg_size, size_t global_idy,
16- AccessorType &global_acc, TResult *local_sums,
15+ void reduce_and_accumulate (sub_group sg, size_t sg_size, size_t global_idy,
16+ AccessorType &global_acc, TResult *local_sums,
1717 size_t count) {
18- for (size_t i = 0 ; i < count; i++) {
19- local_sums[i] = reduce_over_group (sg, local_sums[i], sycl::plus<>());
20-
21- // Only the subgroup leader performs the global accumulation
22- if (global_idy % sg_size == 0 ) {
23- sycl::atomic_ref<TResult, sycl::memory_order::relaxed,
24- sycl::memory_scope::device>
25- aref (global_acc[i]);
26- aref.fetch_add (local_sums[i]);
27- }
28- }
18+ for (size_t i = 0 ; i < count; i++) {
19+ local_sums[i] = reduce_over_group (sg, local_sums[i], sycl::plus<>());
20+
21+ // Only the subgroup leader performs the global accumulation
22+ if (global_idy % sg_size == 0 ) {
23+ sycl::atomic_ref<TResult, sycl::memory_order::relaxed,
24+ sycl::memory_scope::device>
25+ aref (global_acc[i]);
26+ aref.fetch_add (local_sums[i]);
27+ }
28+ }
2929}
3030
3131template <typename T, typename TResult, size_t NUM_ROWS, size_t NUM_COLS,
@@ -100,8 +100,10 @@ void matrix_sum(big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M,
100100 });
101101 }
102102
103- reduce_and_accumulate (sg, sg_size, global_idy, v_rows, sum_local_rows, NUM_ROWS);
104- reduce_and_accumulate (sg, sg_size, global_idy, v_cols, sum_local_cols, NUM_COLS);
103+ reduce_and_accumulate (sg, sg_size, global_idy, v_rows,
104+ sum_local_rows, NUM_ROWS);
105+ reduce_and_accumulate (sg, sg_size, global_idy, v_cols,
106+ sum_local_cols, NUM_COLS);
105107 }); // parallel for
106108 }).wait ();
107109}
@@ -120,8 +122,7 @@ void test_get_coord_op() {
120122 TResult sum_cols[Cols] = {0 };
121123 TResult sum_cols_ref[Cols] = {0 };
122124
123- matrix_fill (Rows, Cols, (T *)M,
124- [](int i, int j) { return T (1 ) * (i + j); });
125+ matrix_fill (Rows, Cols, (T *)M, [](int i, int j) { return T (1 ) * (i + j); });
125126
126127 matrix_vnni<T>(Rows, Cols, *M, *Mvnni, VF);
127128 big_matrix<T, Rows / VF, Cols * VF> MM ((T *)&Mvnni);
0 commit comments