1- // ==-joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp - DPC++ joint_matrix-==//
1+ // ==-joint_matrix_16bit_colmajorA_colmajorB.cpp - DPC++ joint_matrix-==//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
88
9+ // This tests support of col major layout for matrix B which does transpose and
10+ // then VNNI transform. This is currently only available on AMX
11+
12+ // REQUIRES: aspect-ext_intel_matrix
13+
14+ // RUN: %{build} -o %t.out
15+ // RUN: %{run} %t.out
16+ // RUN: %{build} -o %t32.out -DSG_SZ=32
17+ // RUN: %{run} %t32.out
18+
19+ // XFAIL: gpu
20+ // XFAIL-TRACKER: GSD-5768
21+
22+ #include " common.hpp"
23+
// Tile shape of one joint_matrix multiply-add:
//   C(TM x TN) += A(TM x TK) * B(TK x TN)
constexpr size_t TM = 8;
constexpr size_t TN = 16;
constexpr size_t TK = 16;

// Kernel-name tag, parameterized on the element type so that each test<T>()
// instantiation (bfloat16, half) gets a distinct kernel name.
template <typename T> class imatrix;
// Device-side GEMM: computes C += A * B using joint_matrix, loading both A
// and B with layout::col_major (per the file header, a col-major B load does
// a transpose and then a VNNI transform; currently AMX-only).
//
// NOTE(review): this span is a raw diff capture — old ("-") and new ("+")
// lines are interleaved, original line numbers are fused onto the code, and
// the "@@" markers stand for hunks that are NOT visible here. The code is
// kept verbatim; only comments were added. Reconstruct from the upstream
// commit before compiling.
1330template <typename T1, typename T2, size_t M, size_t N, size_t K>
1431void matrix_multiply (big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
1532 big_matrix<T2, K, N> &B) {
// Number of work-groups along M and N: one TM x TN tile of C per group.
1633 size_t NDRangeM = M / TM;
1734 size_t NDRangeN = N / TN;
// "+" side generalizes the buffer element type from hard-coded bfloat16 to
// the template parameter T2 (so half can be tested as well).
18- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, K));
19- buffer<bfloat16 , 2 > bufB (B.get_data (), range<2 >(K, N));
35+ buffer<T2 , 2 > bufA (A.get_data (), range<2 >(M, K));
36+ buffer<T2 , 2 > bufB (B.get_data (), range<2 >(K, N));
2037 buffer<float , 2 > bufC ((float *)C.get_data (), range<2 >(M, N));
2138
2239 queue q;
// Query the runtime subgroup size for this kernel; the kernel-name tag now
// carries T2 so each element type is queried/compiled separately.
23- size_t sg_size = get_sg_size<class imatrix >(q);
40+ size_t sg_size = get_sg_size<class imatrix <T2>>(q);
41+ std::cout << " subgroup size " << sg_size << " " ;
42+
2443 q.submit ([&](handler &cgh) {
2544 auto accC = bufC.get_access <access::mode::read_write>(cgh);
// ".template get_access" is required on the "+" side because bufA/bufB are
// now dependent types (buffer<T2, 2>).
26- auto accA = bufA.get_access <access::mode::read_write>(cgh);
27- auto accB = bufB.get_access <access::mode::read_write>(cgh);
45+ auto accA = bufA.template get_access <access::mode::read_write>(cgh);
46+ auto accB = bufB.template get_access <access::mode::read_write>(cgh);
2847
29- cgh.parallel_for <class imatrix >(
48+ cgh.parallel_for <class imatrix <T2> >(
// 2-D launch: one subgroup (sg_size work-items) per TM x TN tile of C.
3049 nd_range<2 >({NDRangeM, NDRangeN * sg_size}, {1 , 1 * sg_size}),
3150 [=](nd_item<2 > spmd_item)
3251#ifdef SG_SZ
// -- hidden hunk: [[sycl::reqd_sub_group_size]] / global index setup --
@@ -42,10 +61,8 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
4261 const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
4362
4463 sub_group sg = spmd_item.get_sub_group ();
// Fragment declarations; element type generalized bfloat16 -> T2, both A
// and B declared col_major (the point of this test).
45- joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::col_major>
46- sub_a;
47- joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::col_major>
48- sub_b;
64+ joint_matrix<sub_group, T2, use::a, TM, TK, layout::col_major> sub_a;
65+ joint_matrix<sub_group, T2, use::b, TK, TN, layout::col_major> sub_b;
4966 joint_matrix<sub_group, float , use::accumulator, TM, TN> sub_c;
5067
5168 joint_matrix_load (
// -- hidden hunk: loads of A/B, K-loop with joint_matrix_mad, store of C --
@@ -75,31 +92,57 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
7592 }).wait ();
7693}
7794
78- int main () {
95+ template < typename T> void test () {
7996 static constexpr size_t MATRIX_M = TM * 2 ;
8097 static constexpr size_t MATRIX_N = TN * 2 ;
8198 static constexpr size_t MATRIX_K = TK * 2 ;
82- bfloat16 A[MATRIX_K][MATRIX_M];
83- bfloat16 B[MATRIX_N][MATRIX_K];
99+ T A[MATRIX_K][MATRIX_M];
100+ T B[MATRIX_N][MATRIX_K];
84101 float C[MATRIX_M][MATRIX_N];
85102 float D[MATRIX_M][MATRIX_N];
86103
87- matrix_fill (MATRIX_K, MATRIX_M, (bfloat16 *)A,
104+ matrix_fill (MATRIX_K, MATRIX_M, (T *)A,
88105 [](int i, int j) { return 1 .0f * (i + j); });
89- matrix_fill (MATRIX_N, MATRIX_K, (bfloat16 *)B,
106+ matrix_fill (MATRIX_N, MATRIX_K, (T *)B,
90107 [](int i, int j) { return 2 .0f * i + 3 .0f * j; });
91108 matrix_fill (MATRIX_M, MATRIX_N, (float *)C, 1 .0f );
92109 matrix_fill (MATRIX_M, MATRIX_N, (float *)D, 1 .0f );
93110
94111 big_matrix<float , MATRIX_M, MATRIX_N> MC ((float *)&C);
95112 big_matrix<float , MATRIX_M, MATRIX_N> MD ((float *)&D);
96- big_matrix<bfloat16 , MATRIX_M, MATRIX_K> MA ((bfloat16 *)&A);
97- big_matrix<bfloat16 , MATRIX_K, MATRIX_N> MB ((bfloat16 *)&B);
113+ big_matrix<T , MATRIX_M, MATRIX_K> MA ((T *)&A);
114+ big_matrix<T , MATRIX_K, MATRIX_N> MB ((T *)&B);
98115 matrix_multiply (MC, MA, MB);
99- matrix_multiply_ref ((bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M,
100- MATRIX_N, MATRIX_K, false , true , true );
116+ matrix_multiply_ref ((T *)A, (T *)B, (float *)D, MATRIX_M, MATRIX_N, MATRIX_K,
117+ false , true , true );
118+
119+ assert (matrix_compare (MATRIX_M, MATRIX_N, (float *)C, (float *)D));
120+ std::cout << " passed" << std::endl;
121+ }
122+
123+ int main () {
124+ queue q;
125+ std::vector<combination> combinations =
126+ q.get_device ().get_info <syclex::info::device::matrix_combinations>();
127+ bool bf16_run = false ;
128+ bool half_run = false ;
129+
130+ for (auto &combination : combinations) {
131+ if (!bf16_run && combination.atype == matrix_type::bf16 ) {
132+ std::cout << " bf16 " ;
133+ test<bfloat16>();
134+ bf16_run = true ;
135+ }
136+
137+ if (!half_run && combination.atype == matrix_type::fp16) {
138+ std::cout << " half " ;
139+ test<half>();
140+ half_run = true ;
141+ }
142+
143+ if (bf16_run && half_run)
144+ break ;
145+ }
101146
102- bool res = matrix_compare (MATRIX_M, MATRIX_N, (float *)C, (float *)D);
103- std::cout << (res ? " passed" : " failed" ) << std::endl;
104- return !res;
147+ return 0 ;
105148}
0 commit comments