Skip to content

Commit 7c17f1c

Browse files
modify the test() function interface
1 parent 9fbea06 commit 7c17f1c

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,13 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
357357
template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
358358
size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
359359
size_t MCache2, size_t NCache2, size_t KCache2>
360-
void test(size_t matrix_size) {
360+
void test(size_t matrix_size_input) {
361+
#ifdef RUNTIME_DIM
362+
size_t matrix_size = matrix_size_input;
363+
#else
364+
constexpr size_t matrix_size = MATRIX_SIZE;
365+
#endif
366+
361367
assert(matrix_size >= TM && matrix_size >= TK && matrix_size >= TN &&
362368
"invalid matrix size");
363369
assert((matrix_size % TM) == 0 && (matrix_size % TN) == 0 &&
@@ -393,7 +399,7 @@ void test(size_t matrix_size) {
393399
double duration =
394400
joint_matmul<
395401
#if !defined(ARG_DIM) && !defined(RUNTIME_DIM)
396-
MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE, MATRIX_SIZE,
402+
matrix_size, matrix_size, matrix_size, matrix_size,
397403
#endif // ARG_DIM, RUNTIME_DIM
398404
vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
399405
KCache1, MCache2, NCache2, KCache2>
@@ -430,7 +436,7 @@ int main(
430436
#endif //RUNTIME_DIM
431437
) {
432438

433-
size_t matrix_size = MATRIX_SIZE;
439+
size_t matrix_size = -1;
434440
#ifdef RUNTIME_DIM
435441
// Check for command line argument
436442
if (argc == 2) {

0 commit comments

Comments
 (0)