@@ -357,7 +357,13 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
357357template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
358358 size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
359359 size_t MCache2, size_t NCache2, size_t KCache2>
360- void test (size_t matrix_size) {
360+ void test (size_t matrix_size_input) {
361+ #ifdef RUNTIME_DIM
362+ size_t matrix_size = matrix_size_input;
363+ #else
364+ constexpr size_t matrix_size = MATRIX_SIZE;
365+ #endif
366+
361367 assert (matrix_size >= TM && matrix_size >= TK && matrix_size >= TN &&
362368 " invalid matrix size" );
363369 assert ((matrix_size % TM) == 0 && (matrix_size % TN) == 0 &&
@@ -393,7 +399,7 @@ void test(size_t matrix_size) {
393399 double duration =
394400 joint_matmul<
395401#if !defined(ARG_DIM) && !defined(RUNTIME_DIM)
396- MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE ,
402+ matrix_size, matrix_size, matrix_size, matrix_size ,
397403#endif // ARG_DIM, RUNTIME_DIM
398404 vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
399405 KCache1, MCache2, NCache2, KCache2>
@@ -430,7 +436,7 @@ int main(
430436#endif // RUNTIME_DIM
431437 ) {
432438
433- size_t matrix_size = MATRIX_SIZE ;
439+ size_t matrix_size = - 1 ;
434440#ifdef RUNTIME_DIM
435441 // Check for command line argument
436442 if (argc == 2 ) {
0 commit comments