@@ -167,28 +167,43 @@ void test(queue &q) {
167167 sub_c;
168168 joint_matrix<sycl::sub_group, Td, use::accumulator, M, N> sub_d;
169169 auto stride_C = layout_C == layout::row_major ? Big_N : Big_M;
170+ #ifdef OFFSET
171+
172+ joint_matrix_load (
173+ sg, sub_c, accC.template get_multi_ptr <access::decorated::no>(),
174+ m * M, n * N, stride_C, layout_C);
175+ #else
170176 auto load_stride_C = layout_C == layout::row_major
171177 ? (m * M) * Big_N + n * N
172178 : (m * M) + n * N * Big_M;
173-
174179 joint_matrix_load (
175180 sg, sub_c,
176181 accC.template get_multi_ptr <access::decorated::no>() +
177182 load_stride_C,
178183 stride_C, layout_C);
179-
184+ # endif
180185 auto stride_A = layout_A == layout::row_major ? Big_K : Big_M;
181186 auto stride_B = layout_B == layout::row_major ? Big_N : Big_K;
182187
183188 // k = row/col id of current submatrix of BIG A/B matrices
184189 for (int k = 0 ; k < Sub_Tiles_K; k++) {
190+ #ifdef OFFSET
191+ joint_matrix_load (
192+ sg, sub_a,
193+ accA.template get_multi_ptr <access::decorated::no>(), m * M,
194+ k * K, stride_A);
195+
196+ joint_matrix_load (
197+ sg, sub_b,
198+ accB.template get_multi_ptr <access::decorated::no>(), k * K,
199+ n * N, load_stride_B, stride_B);
200+ #else
185201 auto load_stride_A = layout_A == layout::row_major
186202 ? (k * K) + (m * M * Big_K)
187203 : (k * K * Big_M) + (m * M);
188204 auto load_stride_B = layout_B == layout::row_major
189205 ? (k * K * Big_N) + (n * N)
190206 : (k * K) + (n * N * Big_K);
191-
192207 joint_matrix_load (
193208 sg, sub_a,
194209 accA.template get_multi_ptr <access::decorated::no>() +
@@ -200,7 +215,7 @@ void test(queue &q) {
200215 accB.template get_multi_ptr <access::decorated::no>() +
201216 load_stride_B,
202217 stride_B);
203-
218+ # endif
204219 // round values to correct precision if using tf32
205220 if constexpr (std::is_same<T3, precision::tf32>::value) {
206221 auto round_lambda = [](auto &x) { x = round_to_tf32 (x); };
@@ -211,11 +226,17 @@ void test(queue &q) {
211226 joint_matrix_mad (sg, sub_d, sub_a, sub_b, sub_c);
212227 joint_matrix_copy (sg, sub_d, sub_c);
213228 }
229+ #ifdef OFFSET
214230 joint_matrix_store (
231+ sg, sub_d, accD.template get_multi_ptr <access::decorated::no>(),
232+ m * M, n * N, stride_C, layout_C);
233+ #else
234+ joint_matrix_store (
215235 sg, sub_d,
216236 accD.template get_multi_ptr <access::decorated::no>() +
217237 load_stride_C,
218238 stride_C, layout_C);
239+ #endif
219240 });
220241 });
221242 q.wait ();
0 commit comments