@@ -37,17 +37,16 @@ static constexpr void manually_unroll_loop(F &&f) {
3737template <size_t TM, size_t TN, size_t TK> class MatMul ;
3838
3939template <
40- #ifndef ARG_DIM
40+ #if !defined( ARG_DIM) && !defined(RUNTIME_DIM)
4141 size_t rowsA, size_t colsA, size_t rowsB, size_t colsB,
42- #endif // ARG_DIM
42+ #endif // ARG_DIM, RUNTIME_DIM
4343 size_t vnniFactor, typename TOperand, typename TResult, size_t TM,
4444 size_t TN, size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
4545 size_t MCache2, size_t NCache2, size_t KCache2>
46-
4746double joint_matmul (TOperand *A, TOperand *B, TResult *C, queue &q, int i
48- #ifdef ARG_DIM
47+ #if defined( ARG_DIM) || defined(RUNTIME_DIM)
4948 , size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
50- #endif // ARG_DIM
49+ #endif // ARG_DIM, RUNTIME_DIM
5150 ) {
5251
5352 size_t sgSize = get_sg_size<MatMul<TM, TN, TK>>(q);
@@ -296,8 +295,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
296295#ifdef PREFETCH
297296 auto prefetch_offsetA = (m2 * MCache2 + sgId * prefRow) * colsA +
298297 (k2 + prefDistance) * prefCol;
299- if ((prefetch_offsetA + (prefRow * MATRIX_SIZE ) + prefCol) <
300- (MATRIX_SIZE * MATRIX_SIZE ))
298+ if ((prefetch_offsetA + (prefRow * colsA ) + prefCol) <
299+ (rowsA * colsA ))
301300 joint_matrix_prefetch<prefRow, prefCol>(
302301 sg, A + prefetch_offsetA, colsA, layout::row_major,
303302 syclex::properties{syclex::prefetch_hint_L1});
@@ -307,8 +306,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
307306 pm1B * prefRow) *
308307 (colsB)*vnniFactor +
309308 (n2 * NCache2 * vnniFactor + pn1B * prefCol);
310- if ((prefetch_offsetB + (prefRow * MATRIX_SIZE * vnniFactor) +
311- prefCol) < (MATRIX_SIZE * MATRIX_SIZE ))
309+ if ((prefetch_offsetB + (prefRow * colsA * vnniFactor) +
310+ prefCol) < (rowsA * colsA ))
312311 joint_matrix_prefetch<prefRow, prefCol>(
313312 sg, B + prefetch_offsetB, colsB * vnniFactor,
314313 layout::row_major,
@@ -355,35 +354,34 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
355354 return duration.count ();
356355}
357356
358- #ifndef EXCLUDE_MAIN_TEST
359357template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
360358 size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
361359 size_t MCache2, size_t NCache2, size_t KCache2>
362- void test () {
363- assert (MATRIX_SIZE >= TM && MATRIX_SIZE >= TK && MATRIX_SIZE >= TN &&
360+ void test (size_t matrix_size ) {
361+ assert (matrix_size >= TM && matrix_size >= TK && matrix_size >= TN &&
364362 " invalid matrix size" );
365- assert ((MATRIX_SIZE % TM) == 0 && (MATRIX_SIZE % TN) == 0 &&
366- (MATRIX_SIZE % TK) == 0 &&
363+ assert ((matrix_size % TM) == 0 && (matrix_size % TN) == 0 &&
364+ (matrix_size % TK) == 0 &&
367365 " invalid matrix size detected: not a multiple of <TM,TN,TK>" );
368366
369367 std::cout << " Testing: " << TM << " x " << TN << " x " << TK
370368 << " [TM x TN x TK]" << std::endl;
371369
372370 queue q;
373- T *A = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
374- T *B = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
375- TResult *C = malloc_shared<TResult>(MATRIX_SIZE * MATRIX_SIZE , q);
376- TResult *refC = malloc_shared<TResult>(MATRIX_SIZE * MATRIX_SIZE , q);
371+ T *A = malloc_shared<T>(matrix_size * matrix_size , q);
372+ T *B = malloc_shared<T>(matrix_size * matrix_size , q);
373+ TResult *C = malloc_shared<TResult>(matrix_size * matrix_size , q);
374+ TResult *refC = malloc_shared<TResult>(matrix_size * matrix_size , q);
377375
378- matrix_rand<T>(MATRIX_SIZE, MATRIX_SIZE , A, T (1 ));
379- matrix_rand<T>(MATRIX_SIZE, MATRIX_SIZE , B, T (1 ));
376+ matrix_rand<T>(matrix_size, matrix_size , A, T (1 ));
377+ matrix_rand<T>(matrix_size, matrix_size , B, T (1 ));
380378
381- matrix_multiply_ref<T, T, TResult, 1 >(A, B, refC, MATRIX_SIZE, MATRIX_SIZE ,
382- MATRIX_SIZE );
379+ matrix_multiply_ref<T, T, TResult, 1 >(A, B, refC, matrix_size, matrix_size ,
380+ matrix_size );
383381
384382#ifdef VNNI
385- T *vnniB = malloc_shared<T>(MATRIX_SIZE * MATRIX_SIZE , q);
386- matrix_vnni<T>(MATRIX_SIZE, MATRIX_SIZE , B, vnniB, vnniFactor);
383+ T *vnniB = malloc_shared<T>(matrix_size * matrix_size , q);
384+ matrix_vnni<T>(matrix_size, matrix_size , B, vnniB, vnniFactor);
387385 free (B, q);
388386 B = vnniB;
389387#endif
@@ -394,30 +392,30 @@ void test() {
394392
395393 double duration =
396394 joint_matmul<
397- #ifndef ARG_DIM
395+ #if !defined( ARG_DIM) && !defined(RUNTIME_DIM)
398396 MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE,
399- #endif // ARG_DIM
397+ #endif // ARG_DIM, RUNTIME_DIM
400398 vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
401399 KCache1, MCache2, NCache2, KCache2>
402400 (A, B, C, q, i
403- #ifdef ARG_DIM
404- , MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE
405- #endif // ARG_DIM
401+ #if defined( ARG_DIM) || defined(RUNTIME_DIM)
402+ , matrix_size, matrix_size, matrix_size, matrix_size
403+ #endif // ARG_DIM, RUNTIME_DIM
406404 );
407405
408406 if (i >= recordThresh) {
409407 totalDuration += duration;
410408 }
411409 }
412410
413- assert (matrix_compare (MATRIX_SIZE, MATRIX_SIZE , C, refC));
411+ assert (matrix_compare (matrix_size, matrix_size , C, refC));
414412
415413 double msecPerMatrixMul =
416414 totalDuration / static_cast <double >(testIterations - recordThresh);
417- double gflops = (2 .f * MATRIX_SIZE * MATRIX_SIZE * MATRIX_SIZE * 1 .0e-9f ) /
415+ double gflops = (2 .f * matrix_size * matrix_size * matrix_size * 1 .0e-9f ) /
418416 (msecPerMatrixMul / 1000 .f );
419417
420- std::cout << " DONE for size " << MATRIX_SIZE << std::endl;
418+ std::cout << " DONE for size " << matrix_size << std::endl;
421419 std::cout << " GOPS is " << gflops << " Gop/s" << std::endl;
422420
423421 free (A, q);
@@ -426,7 +424,23 @@ void test() {
426424 free (refC, q);
427425}
428426
429- int main () {
427+ int main (
428+ #ifdef RUNTIME_DIM
429+ int argc, char *argv[]
430+ #endif // RUNTIME_DIM
431+ ) {
432+
433+ size_t matrix_size = MATRIX_SIZE;
434+ #ifdef RUNTIME_DIM
435+ // Check for command line argument
436+ if (argc == 2 ) {
437+ matrix_size = std::stoul (argv[1 ]);
438+ } else {
439+ std::cerr << " Usage: ./program matrix_size\n " ;
440+ return 1 ; // Error if no argument
441+ }
442+ #endif // RUNTIME_DIM
443+
430444 queue q;
431445 std::vector<combination> combinations =
432446 q.get_device ()
@@ -449,22 +463,22 @@ int main() {
449463 constexpr size_t NCache1 = 32 ;
450464 constexpr size_t KCache1 = 32 ;
451465 test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 ,
452- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
466+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
453467 break ;
454468 }
455469
456470 if (combinations[i].nsize == 16 ) { // architecture::intel_gpu_pvc
457471 constexpr size_t NCache1 = 4 * /* TN*/ 16 ;
458472 constexpr size_t KCache1 = 16 ;
459473 test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1,
460- NCache1, KCache1, MCache2, NCache2, KCache2>();
474+ NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
461475#if (!defined(SG_SZ) || SG_SZ != 32)
462476 // These combination are not currently supported for subgroup size = 32 in
463477 // IGC
464478 test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 ,
465- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
479+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
466480 test<bfloat16, float , VnniFactor, /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 ,
467- MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
481+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size );
468482#endif
469483 break ;
470484 }
@@ -474,13 +488,11 @@ int main() {
474488 constexpr size_t KCache1 = 16 ;
475489
476490 test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1,
477- NCache1, KCache1, MCache2, NCache2, KCache2>();
478- // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16,
479- // MCache1,
480- // NCache1, KCache1, MCache2, NCache2, KCache2>();
491+ NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
492+ // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
493+ // NCache1, KCache1, MCache2, NCache2, KCache2>(matrix_size);
481494 break ;
482495 }
483496 }
484497 return 0 ;
485498}
486- #endif // EXCLUDE_MAIN_TEST
0 commit comments