99#include < random>
1010#include < sycl/usm.hpp>
1111
12+ #ifdef SLM
13+ #include " slm_utils.hpp"
14+ #endif
15+
1216// number of test iterations
1317constexpr unsigned int testIterations = 100 ;
1418// start recording time after X iterations
@@ -51,6 +55,12 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
5155 std::chrono::high_resolution_clock::now ();
5256
5357 q.submit ([&](handler &h) {
58+ #ifdef SLM
59+ local_accessor<TOperand, 2 > tileA{{MCache2, KCache2}, h};
60+ local_accessor<TOperand, 2 > tileB{
61+ {KCache2 / vnniFactor, NCache2 * vnniFactor}, h};
62+ #endif
63+
5464 h.parallel_for <MatMul<TM, TN, TK>>( // cache layer#1
5565 nd_range<2 >{global, cachelocal},
5666 // loop global
@@ -60,15 +70,16 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
6070 [[intel::reqd_sub_group_size (SG_SZ)]]
6171#endif // SG_SZ
6272 {
73+ // sg::load and sg::store expect decorations to be ON
6374 auto pA =
6475 address_space_cast<sycl::access::address_space::global_space,
65- sycl::access::decorated::no >(A);
76+ sycl::access::decorated::yes >(A);
6677 auto pB =
6778 address_space_cast<sycl::access::address_space::global_space,
68- sycl::access::decorated::no >(B);
79+ sycl::access::decorated::yes >(B);
6980 auto pC =
7081 address_space_cast<sycl::access::address_space::global_space,
71- sycl::access::decorated::no >(C);
82+ sycl::access::decorated::yes >(C);
7283 auto m2 = it.get_group (0 );
7384 auto n2 = it.get_group (1 );
7485 auto m1 = it.get_local_id (0 );
@@ -112,7 +123,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
112123 colsA, layout::row_major,
113124 syclex::properties{syclex::prefetch_hint_L1});
114125
115- #ifdef VNNI
116126 for (int p = 0 ; p < prefDistance; p++)
117127 joint_matrix_prefetch<prefRow, prefCol>(
118128 sg,
@@ -122,15 +132,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
122132 (n2 * NCache2 * vnniFactor + pn1B * prefCol),
123133 colsB * vnniFactor, layout::row_major,
124134 syclex::properties{syclex::prefetch_hint_L1});
125- #else // VNNI
126- for (int p = 0 ; p < prefDistance; p++)
127- joint_matrix_prefetch<prefRow, prefCol>(
128- sg,
129- B + (p * KCache2 + pm1B * prefRow) * colsB + n2 * NCache2 +
130- pn1B * prefCol,
131- colsB, layout::row_major,
132- syclex::properties{syclex::prefetch_hint_L1});
133- #endif // VNNI
134135#endif // PREFETCH
135136
136137 joint_matrix<sub_group, TResult, use::accumulator, TM, TN>
@@ -157,7 +158,16 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
157158 }
158159#endif // MANUAL_UNROLL
159160
161+ #ifdef SLM
162+ constexpr unsigned int SGs =
163+ (MCache2 / MCache1) * (NCache2 / NCache1);
164+ #endif // SLM
160165 for (unsigned int k2 = 0 ; k2 < colsA / KCache2; k2++) {
166+ #ifdef SLM
167+ slm_read_write<colsA, colsB, MCache2, NCache2, KCache2, vnniFactor,
168+ SGs>(pA, pB, tileA, tileB, sg, k2, m2, n2, sgSize);
169+ it.barrier (access::fence_space::local_space);
170+ #endif // SLM
161171 joint_matrix<sub_group, TOperand, use::a, TM, TK, layout::row_major>
162172 tA[MCache1 / TM][KCache2 / KCache1]
163173#ifdef INIT_LIST
@@ -192,6 +202,14 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
192202#else // MANUAL_UNROLL
193203 for (unsigned int m = 0 ; m < MCache1 / TM; m++) {
194204#endif // MANUAL_UNROLL
205+ #ifdef SLM
206+ joint_matrix_load (sg, tA[m][k1],
207+ tileA.template get_multi_ptr <
208+ sycl::access::decorated::no>() +
209+ (m1 * MCache1 + m * TM) * KCache2 +
210+ k1 * TK,
211+ KCache2);
212+ #else // SLM
195213#ifdef OOB
196214 ext::intel::experimental::matrix::joint_matrix_load_checked (
197215 sg, tA[m][k1], pA, colsA, rowsA, colsA,
@@ -203,6 +221,7 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
203221 k * TK,
204222 colsA);
205223#endif // OOB
224+ #endif // SLM
206225#ifdef MANUAL_UNROLL
207226 }); // m
208227#else // MANUAL_UNROLL
@@ -213,32 +232,28 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
213232#else // MANUAL_UNROLL
214233 for (unsigned int n = 0 ; n < NCache1 / TN; n++) {
215234#endif // MANUAL_UNROLL
235+ #ifdef SLM
236+ joint_matrix_load (sg, tB[n][k1],
237+ tileB.template get_multi_ptr <
238+ sycl::access::decorated::no>() +
239+ (k1 * TK / vnniFactor) *
240+ (NCache2 * vnniFactor) +
241+ (n1 * NCache1 + n * TN) * vnniFactor,
242+ NCache2 * vnniFactor);
243+ #else // SLM
216244#ifdef OOB
217- #ifdef VNNI
218245 ext::intel::experimental::matrix::joint_matrix_load_checked (
219246 sg, tB[n][k1], pB, colsB * vnniFactor, rowsB / vnniFactor,
220247 colsB * vnniFactor, k * TK / vnniFactor,
221248 (n2 * NCache2 + n1 * NCache1 + n * TN) * vnniFactor);
222- #else // VNNI
223- ext::intel::experimental::matrix::joint_matrix_load_checked (
224- sg, tB[n][k1], pB, colsB, rowsB, colsB, k * TK,
225- n2 * NCache2 + n1 * NCache1 + n * TN);
226-
227- #endif // VNNI
228249#else // OOB
229- #ifdef VNNI
230250 joint_matrix_load (
231251 sg, tB[n][k1],
232252 pB + (k * TK / vnniFactor) * (colsB * vnniFactor) +
233253 (n2 * NCache2 + n1 * NCache1 + n * TN) * vnniFactor,
234254 colsB * vnniFactor);
235- #else // VNNI
236- joint_matrix_load (sg, tB[n][k1],
237- pB + (k * TK) * (colsB) +
238- (n2 * NCache2 + n1 * NCache1 + n * TN),
239- colsB);
240- #endif // VNNI
241255#endif // OOB
256+ #endif // SLM
242257#ifdef MANUAL_UNROLL
243258 }); // n
244259#else // MANUAL_UNROLL
@@ -266,6 +281,9 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
266281 } // m
267282 } // k1
268283#endif // MANUAL_UNROLL
284+ #ifdef SLM
285+ it.barrier (access::fence_space::local_space);
286+ #endif // SLM
269287#ifdef PREFETCH
270288 auto prefetch_offsetA = (m2 * MCache2 + sgId * prefRow) * colsA +
271289 (k2 + prefDistance) * prefCol;
@@ -275,7 +293,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
275293 sg, A + prefetch_offsetA, colsA, layout::row_major,
276294 syclex::properties{syclex::prefetch_hint_L1});
277295
278- #ifdef VNNI
279296 auto prefetch_offsetB =
280297 ((k2 + prefDistance) * (KCache2 / vnniFactor) +
281298 pm1B * prefRow) *
@@ -287,16 +304,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
287304 sg, B + prefetch_offsetB, colsB * vnniFactor,
288305 layout::row_major,
289306 syclex::properties{syclex::prefetch_hint_L1});
290- #else // VNNI
291- auto prefetch_offsetB =
292- ((k2 + prefDistance) * KCache2 + pm1B * prefRow) * (colsB) +
293- (n2 * NCache2 + pn1B * prefCol);
294- if ((prefetch_offsetB + (prefRow * MATRIX_SIZE) + prefCol) <
295- (MATRIX_SIZE * MATRIX_SIZE))
296- joint_matrix_prefetch<prefRow, prefCol>(
297- sg, B + prefetch_offsetB, colsB, layout::row_major,
298- syclex::properties{syclex::prefetch_hint_L1});
299- #endif // VNNI
300307#endif // PREFETCH
301308 } // for k2
302309#ifdef MANUAL_UNROLL
@@ -411,29 +418,33 @@ int main() {
411418 constexpr size_t NCache2 = 256 ;
412419 constexpr size_t KCache2 = 32 ;
413420
421+ #ifdef VNNI
422+ constexpr unsigned int VnniFactor = 2 ;
423+ #else // VNNI
424+ constexpr unsigned int VnniFactor = 1 ;
425+ #endif // VNNI
426+
414427 for (unsigned int i = 0 ; i < combinations.size (); i++) {
415428 if (combinations[i].nsize == 0 ) { // Intel AMX
416429 constexpr size_t NCache1 = 32 ;
417430 constexpr size_t KCache1 = 32 ;
418-
419- test<bfloat16, float , 2 , /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 , MCache1,
420- NCache1, KCache1, MCache2, NCache2, KCache2>();
431+ test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 ,
432+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
421433 break ;
422434 }
423435
424436 if (combinations[i].nsize == 16 ) { // architecture::intel_gpu_pvc
425437 constexpr size_t NCache1 = 4 * /* TN*/ 16 ;
426438 constexpr size_t KCache1 = 16 ;
427-
428- test<bfloat16, float , 2 , /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1, NCache1,
429- KCache1, MCache2, NCache2, KCache2>();
439+ test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1,
440+ NCache1, KCache1, MCache2, NCache2, KCache2>();
430441#if (!defined(SG_SZ) || SG_SZ != 32)
431442 // These combination are not currently supported for subgroup size = 32 in
432443 // IGC
433- test<bfloat16, float , 2 , /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 , MCache1 ,
434- NCache1, KCache1, MCache2, NCache2, KCache2>();
435- test<bfloat16, float , 2 , /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 , MCache1 ,
436- NCache1, KCache1, MCache2, NCache2, KCache2>();
444+ test<bfloat16, float , VnniFactor , /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 ,
445+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
446+ test<bfloat16, float , VnniFactor , /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 ,
447+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
437448#endif
438449 break ;
439450 }
@@ -442,9 +453,10 @@ int main() {
442453 constexpr size_t NCache1 = 4 * /* TN*/ 8 ;
443454 constexpr size_t KCache1 = 16 ;
444455
445- test<bfloat16, float , 2 , /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1, NCache1,
446- KCache1, MCache2, NCache2, KCache2>();
447- // test<bfloat16, float, 2, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
456+ test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1,
457+ NCache1, KCache1, MCache2, NCache2, KCache2>();
458+ // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16,
459+ // MCache1,
448460 // NCache1, KCache1, MCache2, NCache2, KCache2>();
449461 break ;
450462 }
0 commit comments