diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB.cpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB.cpp index 82bedf7043e9d..2519b0fdb4c79 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB.cpp @@ -17,7 +17,4 @@ // XFAIL-TRACKER: GSD-5768 #include "common.hpp" - -constexpr size_t TN = 16; - #include "joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp" diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index d8f5e45474a77..bab88721fb1a9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -6,8 +6,9 @@ // //===----------------------------------------------------------------------===// -#define TM 8 -#define TK 16 +constexpr size_t TM = 8; +constexpr size_t TN = 16; +constexpr size_t TK = 16; template void matrix_multiply(big_matrix &C, big_matrix &A, @@ -43,7 +44,6 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; - // For B, we assume B has been already VNNIed. joint_matrix sub_b; joint_matrix sub_c;