1111template <typename T, size_t Rows, size_t Cols, layout Layout, use Use>
1212class matrix_process ;
1313
14+ template <typename TResult, typename AccessorType>
15+ void reduce_and_accumulate (sub_group sg, size_t sg_size, size_t global_idy,
16+ AccessorType &global_acc, TResult *local_sums,
17+ size_t count) {
18+ for (size_t i = 0 ; i < count; i++) {
19+ local_sums[i] = reduce_over_group (sg, local_sums[i], sycl::plus<>());
20+
21+ // Only the subgroup leader performs the global accumulation
22+ if (global_idy % sg_size == 0 ) {
23+ sycl::atomic_ref<TResult, sycl::memory_order::relaxed,
24+ sycl::memory_scope::device>
25+ aref (global_acc[i]);
26+ aref.fetch_add (local_sums[i]);
27+ }
28+ }
29+ }
30+
1431template <typename T, typename TResult, size_t NUM_ROWS, size_t NUM_COLS,
1532 size_t SROWS, size_t SCOLS, use Use, layout Layout, size_t VF>
1633void matrix_sum (big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M,
@@ -32,7 +49,7 @@ void matrix_sum(big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M,
3249 {1 , 1 * sg_size}),
3350 [=](nd_item<2 > spmd_item)
3451#ifdef SG_SZ
35- [[intel ::reqd_sub_group_size (SG_SZ)]]
52+ [[sycl ::reqd_sub_group_size (SG_SZ)]]
3653#endif
3754 {
3855 // The submatrix API has to be accessed by all the workitems in a
@@ -83,29 +100,10 @@ void matrix_sum(big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> &M,
83100 });
84101 }
85102
86- for (int i = 0 ; i < NUM_ROWS; i++) {
87- sum_local_rows[i] =
88- reduce_over_group (sg, sum_local_rows[i], sycl::plus<>());
89- // only Groups leader perform the global reduction
90- if (global_idy % sg_size == 0 ) {
91- sycl::atomic_ref<TResult, sycl::memory_order::relaxed,
92- sycl::memory_scope::device>
93- aref (v_rows[i]);
94- aref.fetch_add (sum_local_rows[i]);
95- }
96- }
97-
98- for (int i = 0 ; i < NUM_COLS; i++) {
99- sum_local_cols[i] =
100- reduce_over_group (sg, sum_local_cols[i], sycl::plus<>());
101- // only Groups leader perform the global reduction
102- if (global_idy % sg_size == 0 ) {
103- sycl::atomic_ref<TResult, sycl::memory_order::relaxed,
104- sycl::memory_scope::device>
105- aref (v_cols[i]);
106- aref.fetch_add (sum_local_cols[i]);
107- }
108- }
103+ reduce_and_accumulate (sg, sg_size, global_idy, v_rows,
104+ sum_local_rows, NUM_ROWS);
105+ reduce_and_accumulate (sg, sg_size, global_idy, v_cols,
106+ sum_local_cols, NUM_COLS);
109107 }); // parallel for
110108 }).wait ();
111109}
@@ -124,11 +122,7 @@ void test_get_coord_op() {
124122 TResult sum_cols[Cols] = {0 };
125123 TResult sum_cols_ref[Cols] = {0 };
126124
127- for (int i = 0 ; i < Rows; i++) {
128- for (int j = 0 ; j < Cols; j++) {
129- M[i][j] = i + j;
130- }
131- }
125+ matrix_fill (Rows, Cols, (T *)M, [](int i, int j) { return T (i + j); });
132126
133127 matrix_vnni<T>(Rows, Cols, *M, *Mvnni, VF);
134128 big_matrix<T, Rows / VF, Cols * VF> MM ((T *)&Mvnni);
0 commit comments