diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 5f898101de031..cab0b2f599575 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -118,35 +118,39 @@ joint_matrix_apply(Group sg, joint_matrix &jm, return; } -template inline __SYCL_ALWAYS_INLINE void -joint_matrix_apply(Group sg, joint_matrix &jmsrc, - joint_matrix &jmdest, +joint_matrix_apply(Group sg, joint_matrix &jm0, + joint_matrix &jm1, F &&lambda) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; - for (int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) { - lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]); + for (int i = 0; i < jm0.matrix_impl.wi_marray.size(); i++) { + lambda(jm0.matrix_impl.wi_marray[i], jm1.matrix_impl.wi_marray[i]); } #else // NVPTX - using storage_element_type = + using storage_element_type0 = typename oneapi::detail::jm_type_interpretation_helper_trait< - T>::storage_element_type; - auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jmsrc); - auto wi_data_d = sycl::ext::oneapi::detail::get_wi_data(sg, jmdest); - for (int i = 0; i < wi_data_c.length(); i++) { - storage_element_type elementsrc = wi_data_c[i]; - storage_element_type elementdest = wi_data_d[i]; - lambda(elementsrc, elementdest); - wi_data_d[i] = elementdest; + T0>::storage_element_type; + using storage_element_type1 = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T1>::storage_element_type; + auto wi_data_0 = sycl::ext::oneapi::detail::get_wi_data(sg, jm0); + auto wi_data_1 = sycl::ext::oneapi::detail::get_wi_data(sg, jm1); + for (int i = 0; i < wi_data_0.length(); i++) { + storage_element_type0 element0 = wi_data_0[i]; + storage_element_type1 element1 = wi_data_1[i]; + lambda(element0, element1); + wi_data_0[i] = element0; + wi_data_1[i] = element1; } #endif #else std::ignore = sg; - std::ignore = jmsrc; - std::ignore = jmdest; + std::ignore = jm0; + std::ignore = jm1; std::ignore = lambda; throw exception(make_error_code(errc::runtime), "joint matrix is not supported on host."); diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index 90f5508d97cf8..db184466649d5 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -156,6 +156,13 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) { } } +template +void matrix_apply(unsigned int rows, unsigned int cols, T *mat, F op) { + for (unsigned int i = 0; i < rows; i++) + for (unsigned int j = 0; j < cols; j++) + mat[i * cols + j] = op(mat[i * cols + j]); +} + template bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { for (int i = 0; i < rows; i++) { @@ -173,7 +180,7 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { << ", Epsilon: " << FLOAT_EPSILON << "\n"; return false; } - } else if constexpr (exact || std::is_same_v) { + } else if constexpr (exact || std::is_integral_v) { if (src[i * cols + j] != ref[i * cols + j]) { std::cout << "Incorrect result in matrix." << "i: " << i << ", j: " << j diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp index a88b0ca55416e..e8fdf866e641a 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp @@ -7,29 +7,26 @@ //===----------------------------------------------------------------------===// #include -template -bool apply_verify(Tc *C, Tc *D, Ta *A, Ta *Ar) { - for (size_t i = 0; i < M; i++) - for (size_t j = 0; j < N; j++) { - Tc diffc = D[i * N + j] - C[i * N + j] * 2; - Ta diffa = Ar[i * N + j] - (A[i * N + j] + 42); - if constexpr (std::is_same_v) { - if (std::fabs(diffc) > FLOAT_EPSILON || - std::fabs(diffa) > FLOAT_EPSILON || std::isnan(C[i * N + j]) || - std::isnan(A[i * N + j])) { - return false; - } - } else { - if (std::abs(diffc) > 0 || std::abs(diffa) > 0) { - return false; - } - } - } - return true; +template T mul2(T x) { return x * 2; } + +template T add5(T x) { return x + 5; } + +template +bool apply_verify(Tc *C, Tc *D, Tc *ref) { + Tc *refcopy = (Tc *)std::malloc(M * N * sizeof(Tc)); + memcpy(refcopy, ref, M * N * sizeof(Tc)); + matrix_apply(M, N, ref, mul2); + bool res = matrix_compare(M, N, D, ref); + + matrix_apply(M, N, refcopy, add5); + res &= matrix_compare(M, N, C, refcopy); + return res; } + template -bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) { +bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, Tc *Cref, Ta *Aref, + queue q) { size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; @@ -70,22 +67,33 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) { joint_matrix_load( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN, N, layout::row_major); - joint_matrix_apply(sg, sub_c, sub_d, - [](const Tc &x, Tc &y) { y = x * 2; }); + joint_matrix_apply(sg, sub_c, sub_d, [](Tc &x, Tc &y) { + y = mul2(x); + x = add5(x); + }); joint_matrix_store( sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN, N, layout::row_major); + joint_matrix_store( + sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN, + N, layout::row_major); joint_matrix_load( sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK, K); - joint_matrix_apply(sg, sub_a, sub_ar, - [](const Ta &x, Ta &y) { y = x + 42; }); + joint_matrix_apply(sg, sub_a, sub_ar, [](Ta &x, Ta &y) { + y = mul2(x); + x = add5(x); + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_ar, pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK, + K); }); // parallel for }).wait(); - return apply_verify(C, D, A, Ar); + return apply_verify(C, D, Cref) && + apply_verify(A, Ar, Aref); } template (M * N, q); + Ta *Aref = malloc_shared(M * K, q); Tc *C = malloc_shared(M * N, q); Tc *D = malloc_shared(M * N, q); Ta *A = malloc_shared(M * K, q); Ta *Ar = malloc_shared(M * K, q); - matrix_rand(M, N, (Tc *)C, (Tc)100); - matrix_rand(M, K, (Ta *)A, (Ta)100); + matrix_rand(M, N, (Tc *)Cref, (Tc)100); + matrix_rand(M, K, (Ta *)Aref, (Ta)100); + matrix_copy(M, N, Cref, C); + matrix_copy(M, K, Aref, A); bool res = apply_two_matrices( - C, D, A, Ar, q); + C, D, A, Ar, Cref, Aref, q); if constexpr (std::is_same_v) std::cout << "bfloat16 " << TM << "x" << TN << "x" << TK << ": " @@ -117,6 +129,8 @@ bool test() { free(D, q); free(A, q); free(Ar, q); + free(Cref, q); + free(Aref, q); return res; }