77// ===----------------------------------------------------------------------===//
88#include < sycl/usm.hpp>
99
10- template <typename Tc, typename Ta, size_t M, size_t N>
11- bool apply_verify (Tc *C, Tc *D, Ta *A, Ta *Ar) {
12- for (size_t i = 0 ; i < M; i++)
13- for (size_t j = 0 ; j < N; j++) {
14- Tc diffc = D[i * N + j] - C[i * N + j] * 2 ;
15- Ta diffa = Ar[i * N + j] - (A[i * N + j] + 42 );
16- if constexpr (std::is_same_v<Ta, bfloat16>) {
17- if (std::fabs (diffc) > FLOAT_EPSILON ||
18- std::fabs (diffa) > FLOAT_EPSILON || std::isnan (C[i * N + j]) ||
19- std::isnan (A[i * N + j])) {
20- return false ;
21- }
22- } else {
23- if (std::abs (diffc) > 0 || std::abs (diffa) > 0 ) {
24- return false ;
25- }
26- }
27- }
28- return true ;
10+ template <typename T> T mul2 (T x) { return x * 2 ; }
11+
12+ template <typename T> T add5 (T x) { return x + 5 ; }
13+
14+ template <typename Tc, size_t M, size_t N>
15+ bool apply_verify (Tc *C, Tc *D, Tc *ref) {
16+ Tc *refcopy = (Tc *)std::malloc (M * N * sizeof (Tc));
17+ memcpy (refcopy, ref, M * N * sizeof (Tc));
18+ matrix_apply (M, N, ref, mul2<Tc>);
19+ bool res = matrix_compare (M, N, D, ref);
20+
21+ matrix_apply (M, N, refcopy, add5<Tc>);
22+ res &= matrix_compare (M, N, C, refcopy);
23+ return res;
2924}
25+
3026template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
3127 size_t N, size_t K, class kernel_name >
32- bool apply_two_matrices (Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
28+ bool apply_two_matrices (Tc *C, Tc *D, Ta *A, Ta *Ar, Tc *Cref, Ta *Aref,
29+ queue q) {
3330 size_t NDRangeM = M / TM;
3431 size_t NDRangeN = N / TN;
3532
@@ -70,22 +67,33 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
7067 joint_matrix_load (
7168 sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
7269 N, layout::row_major);
73- joint_matrix_apply (sg, sub_c, sub_d,
74- [](const Tc &x, Tc &y) { y = x * 2 ; });
70+ joint_matrix_apply (sg, sub_c, sub_d, [](Tc &x, Tc &y) {
71+ y = mul2 (x);
72+ x = add5 (x);
73+ });
7574 joint_matrix_store (
7675 sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
7776 N, layout::row_major);
77+ joint_matrix_store (
78+ sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
79+ N, layout::row_major);
7880 joint_matrix_load (
7981 sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
8082 K);
81- joint_matrix_apply (sg, sub_a, sub_ar,
82- [](const Ta &x, Ta &y) { y = x + 42 ; });
83+ joint_matrix_apply (sg, sub_a, sub_ar, [](Ta &x, Ta &y) {
84+ y = mul2 (x);
85+ x = add5 (x);
86+ });
8387 ext::intel::experimental::matrix::joint_matrix_store (
8488 sg, sub_ar,
8589 pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K);
90+ ext::intel::experimental::matrix::joint_matrix_store (
91+ sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
92+ K);
8693 }); // parallel for
8794 }).wait ();
88- return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
95+ return apply_verify<Tc, M, N>(C, D, Cref) &&
96+ apply_verify<Ta, M, N>(A, Ar, Aref);
8997}
9098
9199template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
@@ -96,16 +104,20 @@ bool test() {
96104 static constexpr size_t K = TK * 2 ;
97105 queue q;
98106
107+ Tc *Cref = malloc_shared<Tc>(M * N, q);
108+ Ta *Aref = malloc_shared<Ta>(M * K, q);
99109 Tc *C = malloc_shared<Tc>(M * N, q);
100110 Tc *D = malloc_shared<Tc>(M * N, q);
101111 Ta *A = malloc_shared<Ta>(M * K, q);
102112 Ta *Ar = malloc_shared<Ta>(M * K, q);
103113
104- matrix_rand (M, N, (Tc *)C, (Tc)100 );
105- matrix_rand (M, K, (Ta *)A, (Ta)100 );
114+ matrix_rand (M, N, (Tc *)Cref, (Tc)100 );
115+ matrix_rand (M, K, (Ta *)Aref, (Ta)100 );
116+ matrix_copy (M, N, Cref, C);
117+ matrix_copy (M, K, Aref, A);
106118
107119 bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
108- C, D, A, Ar, q);
120+ C, D, A, Ar, Cref, Aref, q);
109121
110122 if constexpr (std::is_same_v<Ta, bfloat16>)
111123 std::cout << " bfloat16 " << TM << " x" << TN << " x" << TK << " : "
@@ -117,6 +129,8 @@ bool test() {
117129 free (D, q);
118130 free (A, q);
119131 free (Ar, q);
132+ free (Cref, q);
133+ free (Aref, q);
120134
121135 return res;
122136}
0 commit comments