@@ -36,11 +36,19 @@ static constexpr void manually_unroll_loop(F &&f) {
3636
3737template <size_t TM, size_t TN, size_t TK> class MatMul ;
3838
39- template <size_t rowsA, size_t colsA, size_t rowsB, size_t colsB,
39+ template <
40+ #if !defined(ARG_DIM) && !defined(RUNTIME_DIM)
41+ size_t rowsA, size_t colsA, size_t rowsB, size_t colsB,
42+ #endif // ARG_DIM, RUNTIME_DIM
4043 size_t vnniFactor, typename TOperand, typename TResult, size_t TM,
4144 size_t TN, size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
4245 size_t MCache2, size_t NCache2, size_t KCache2>
43- double joint_matmul (TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
46+ double joint_matmul (TOperand *A, TOperand *B, TResult *C, queue &q, int i
47+ #if defined(ARG_DIM) || defined(RUNTIME_DIM)
48+ , size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
49+ #endif // ARG_DIM, RUNTIME_DIM
50+ ) {
51+
4452 size_t sgSize = get_sg_size<MatMul<TM, TN, TK>>(q);
4553 range<2 > global{rowsA / MCache1, (colsB / NCache1) * sgSize};
4654 range<2 > cachelocal{MCache2 / MCache1, NCache2 / NCache1 * sgSize};
@@ -287,8 +295,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
287295#ifdef PREFETCH
288296 auto prefetch_offsetA = (m2 * MCache2 + sgId * prefRow) * colsA +
289297 (k2 + prefDistance) * prefCol;
290- if ((prefetch_offsetA + (prefRow * MATRIX_SIZE ) + prefCol) <
291- (MATRIX_SIZE * MATRIX_SIZE ))
298+ if ((prefetch_offsetA + (prefRow * colsA ) + prefCol) <
299+ (rowsA * colsA ))
292300 joint_matrix_prefetch<prefRow, prefCol>(
293301 sg, A + prefetch_offsetA, colsA, layout::row_major,
294302 syclex::properties{syclex::prefetch_hint_L1});
@@ -298,8 +306,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
298306 pm1B * prefRow) *
299307 (colsB)*vnniFactor +
300308 (n2 * NCache2 * vnniFactor + pn1B * prefCol);
301- if ((prefetch_offsetB + (prefRow * MATRIX_SIZE * vnniFactor) +
302- prefCol) < (MATRIX_SIZE * MATRIX_SIZE ))
309+ if ((prefetch_offsetB + (prefRow * colsB * vnniFactor) +
310+ prefCol) < (rowsB * colsB ))
303311 joint_matrix_prefetch<prefRow, prefCol>(
304312 sg, B + prefetch_offsetB, colsB * vnniFactor,
305313 layout::row_major,
@@ -349,31 +357,37 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
349357template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
350358 size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
351359 size_t MCache2, size_t NCache2, size_t KCache2>
352- void test () {
353- assert (MATRIX_SIZE >= TM && MATRIX_SIZE >= TK && MATRIX_SIZE >= TN &&
360+ void test (size_t matrix_size_input) {
361+ #ifdef RUNTIME_DIM
362+ size_t matrix_size = matrix_size_input;
363+ #else
364+ constexpr size_t matrix_size = MATRIX_SIZE;
365+ #endif // RUNTIME_DIM
366+
367+ assert (matrix_size >= TM && matrix_size >= TK && matrix_size >= TN &&
354368 " invalid matrix size" );
355- assert ((MATRIX_SIZE % TM) == 0 && (MATRIX_SIZE % TN) == 0 &&
356- (MATRIX_SIZE % TK) == 0 &&
369+ assert ((matrix_size % TM) == 0 && (matrix_size % TN) == 0 &&
370+ (matrix_size % TK) == 0 &&
357371 " invalid matrix size detected: not a multiple of <TM,TN,TK>" );
358372
359373 std::cout << " Testing: " << TM << " x " << TN << " x " << TK
360374 << " [TM x TN x TK]" << std::endl;
361375
362376 queue q;
363- T *A = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
364- T *B = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
365- TResult *C = malloc_shared<TResult>(MATRIX_SIZE * MATRIX_SIZE , q);
366- TResult *refC = malloc_shared<TResult>(MATRIX_SIZE * MATRIX_SIZE , q);
377+ T *A = malloc_shared<T>(matrix_size * matrix_size , q);
378+ T *B = malloc_shared<T>(matrix_size * matrix_size , q);
379+ TResult *C = malloc_shared<TResult>(matrix_size * matrix_size , q);
380+ TResult *refC = malloc_shared<TResult>(matrix_size * matrix_size , q);
367381
368- matrix_rand<T>(MATRIX_SIZE, MATRIX_SIZE , A, T (1 ));
369- matrix_rand<T>(MATRIX_SIZE, MATRIX_SIZE , B, T (1 ));
382+ matrix_rand<T>(matrix_size, matrix_size , A, T (1 ));
383+ matrix_rand<T>(matrix_size, matrix_size , B, T (1 ));
370384
371- matrix_multiply_ref<T, T, TResult, 1 >(A, B, refC, MATRIX_SIZE, MATRIX_SIZE ,
372- MATRIX_SIZE );
385+ matrix_multiply_ref<T, T, TResult, 1 >(A, B, refC, matrix_size, matrix_size ,
386+ matrix_size );
373387
374388#ifdef VNNI
375- T *vnniB = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
376- matrix_vnni<T>(MATRIX_SIZE, MATRIX_SIZE , B, vnniB, vnniFactor);
389+ T *vnniB = malloc_shared<T>(matrix_size * matrix_size , q);
390+ matrix_vnni<T>(matrix_size, matrix_size , B, vnniB, vnniFactor);
377391 free (B, q);
378392 B = vnniB;
379393#endif
@@ -382,22 +396,31 @@ void test() {
382396 double totalDuration = 0 ;
383397 for (unsigned int i = 0 ; i < testIterations; i++) {
384398 double duration =
385- joint_matmul<MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE,
386- vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
387- KCache1, MCache2, NCache2, KCache2>(A, B, C, q, i);
399+ joint_matmul<
400+ #if !defined(ARG_DIM) && !defined(RUNTIME_DIM)
401+ matrix_size, matrix_size, matrix_size, matrix_size,
402+ #endif // ARG_DIM, RUNTIME_DIM
403+ vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
404+ KCache1, MCache2, NCache2, KCache2>
405+ (A, B, C, q, i
406+ #if defined(ARG_DIM) || defined(RUNTIME_DIM)
407+ , matrix_size, matrix_size, matrix_size, matrix_size
408+ #endif // ARG_DIM, RUNTIME_DIM
409+ );
410+
388411 if (i >= recordThresh) {
389412 totalDuration += duration;
390413 }
391414 }
392415
393- assert (matrix_compare (MATRIX_SIZE, MATRIX_SIZE , C, refC));
416+ assert (matrix_compare (matrix_size, matrix_size , C, refC));
394417
395418 double msecPerMatrixMul =
396419 totalDuration / static_cast <double >(testIterations - recordThresh);
397- double gflops = (2 .f * MATRIX_SIZE * MATRIX_SIZE * MATRIX_SIZE * 1 .0e-9f ) /
420+ double gflops = (2 .f * matrix_size * matrix_size * matrix_size * 1 .0e-9f ) /
398421 (msecPerMatrixMul / 1000 .f );
399422
400- std::cout << " DONE for size " << MATRIX_SIZE << std::endl;
423+ std::cout << " DONE for size " << matrix_size << std::endl;
401424 std::cout << " GOPS is " << gflops << " Gop/s" << std::endl;
402425
403426 free (A, q);
@@ -406,7 +429,22 @@ void test() {
406429 free (refC, q);
407430}
408431
409- int main () {
432+ int main (
433+ #ifdef RUNTIME_DIM
434+ int argc, char *argv[]
435+ #endif // RUNTIME_DIM
436+ ) {
437+
438+ size_t matrix_size = -1 ;
439+ #ifdef RUNTIME_DIM
440+ if (argc == 2 ) {
441+ matrix_size = std::stoul (argv[1 ]);
442+ } else {
443+ std::cerr << " Usage: ./program matrix_size\n " ;
444+ return 1 ; // Error if no argument
445+ }
446+ #endif // RUNTIME_DIM
447+
410448 queue q;
411449 std::vector<combination> combinations =
412450 q.get_device ()
@@ -429,22 +467,22 @@ int main() {
429467 constexpr size_t NCache1 = 32 ;
430468 constexpr size_t KCache1 = 32 ;
431469 test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 ,
432- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
470+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
433471 break ;
434472 }
435473
436474 if (combinations[i].nsize == 16 ) { // architecture::intel_gpu_pvc
437475 constexpr size_t NCache1 = 4 * /* TN*/ 16 ;
438476 constexpr size_t KCache1 = 16 ;
439477 test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1,
440- NCache1, KCache1, MCache2, NCache2, KCache2>();
478+ NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
441479#if (!defined(SG_SZ) || SG_SZ != 32)
442480 // These combination are not currently supported for subgroup size = 32 in
443481 // IGC
444482 test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 ,
445- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
483+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
446484 test<bfloat16, float , VnniFactor, /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 ,
447- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
485+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
448486#endif
449487 break ;
450488 }
@@ -454,10 +492,9 @@ int main() {
454492 constexpr size_t KCache1 = 16 ;
455493
456494 test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1,
457- NCache1, KCache1, MCache2, NCache2, KCache2>();
458- // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16,
459- // MCache1,
460- // NCache1, KCache1, MCache2, NCache2, KCache2>();
495+ NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
496+ // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
497+ // NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
461498 break ;
462499 }
463500 }
0 commit comments