Skip to content

Commit 19b1fb4

Browse files
committed
[SYCL][Matrix]Add support for missing matrix combinations for half and bfloat16 types
1 parent ac71f61 commit 19b1fb4

File tree

4 files changed

+122
-16
lines changed

4 files changed

+122
-16
lines changed

sycl/source/detail/device_info.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,14 +849,96 @@ struct get_device_info_impl<
849849
matrix_type::sint32, matrix_type::sint32},
850850
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
851851
matrix_type::fp32, matrix_type::fp32},
852+
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
853+
matrix_type::fp16, matrix_type::fp32},
854+
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
855+
matrix_type::fp32, matrix_type::fp16},
856+
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
857+
matrix_type::fp16, matrix_type::fp16},
858+
{0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
859+
matrix_type::fp32, matrix_type::fp16},
860+
{0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
861+
matrix_type::fp16, matrix_type::fp16},
862+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
863+
matrix_type::fp32, matrix_type::fp32},
864+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
865+
matrix_type::fp16, matrix_type::fp32},
866+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
867+
matrix_type::fp32, matrix_type::fp16},
868+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
869+
matrix_type::fp16, matrix_type::fp16},
870+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
871+
matrix_type::fp32, matrix_type::fp32},
872+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
873+
matrix_type::fp16, matrix_type::fp32},
874+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
875+
matrix_type::fp32, matrix_type::bf16},
876+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
877+
matrix_type::fp16, matrix_type::fp16},
878+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
879+
matrix_type::fp32, matrix_type::fp32},
880+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
881+
matrix_type::fp16, matrix_type::fp32},
882+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
883+
matrix_type::fp32, matrix_type::fp16},
884+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
885+
matrix_type::fp16, matrix_type::fp16},
886+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
887+
matrix_type::fp32, matrix_type::fp32},
888+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
889+
matrix_type::fp16, matrix_type::fp32},
890+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
891+
matrix_type::fp32, matrix_type::fp16},
892+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
893+
matrix_type::fp16, matrix_type::fp16},
894+
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
895+
matrix_type::bf16, matrix_type::bf16},
896+
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
897+
matrix_type::fp32, matrix_type::bf16},
898+
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
899+
matrix_type::bf16, matrix_type::fp32},
852900
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
853901
matrix_type::fp32, matrix_type::fp32},
854902
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
855903
matrix_type::fp32, matrix_type::fp32},
904+
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
905+
matrix_type::bf16, matrix_type::fp32},
906+
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
907+
matrix_type::fp32, matrix_type::bf16},
908+
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
909+
matrix_type::bf16, matrix_type::bf16},
856910
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
857911
matrix_type::fp32, matrix_type::fp32},
912+
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
913+
matrix_type::bf16, matrix_type::fp32},
914+
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
915+
matrix_type::fp32, matrix_type::bf16},
916+
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
917+
matrix_type::bf16, matrix_type::bf16},
858918
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
859919
matrix_type::fp32, matrix_type::fp32},
920+
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
921+
matrix_type::bf16, matrix_type::fp32},
922+
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
923+
matrix_type::fp32, matrix_type::bf16},
924+
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
925+
matrix_type::bf16, matrix_type::bf16},
926+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
927+
matrix_type::fp32, matrix_type::fp32},
928+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
929+
matrix_type::bf16, matrix_type::fp32},
930+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
931+
matrix_type::fp32, matrix_type::bf16},
932+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
933+
matrix_type::bf16, matrix_type::bf16},
934+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
935+
matrix_type::fp32, matrix_type::fp32},
936+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
937+
matrix_type::bf16, matrix_type::fp32},
938+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
939+
matrix_type::fp32, matrix_type::bf16},
940+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
941+
matrix_type::bf16, matrix_type::bf16},
860942
{8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
861943
matrix_type::fp32, matrix_type::fp32},
862944
};

sycl/test-e2e/Matrix/common.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,17 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
6363
if constexpr (std::is_same_v<Ta, bfloat16> &&
6464
std::is_same_v<Tc, float>)
6565
acc += make_fp32(va[i]) * make_fp32(vb[i]);
66+
else if constexpr (std::is_same_v<Ta, sycl::half> &&
67+
std::is_same_v<Tc, float>)
68+
acc += (float)va[i] * (float)vb[i];
6669
else if constexpr (std::is_same_v<Ta, float> &&
6770
std::is_same_v<Tc, float> ||
6871
std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
72+
(std::is_same_v<Ta, bfloat16> ||
73+
std::is_same_v<Ta, sycl::half>) ||
6974
(std::is_same_v<Ta, double> &&
7075
std::is_same_v<Tc, double>))
7176
acc += va[i] * vb[i];
72-
else if constexpr (std::is_same_v<Ta, sycl::half> &&
73-
std::is_same_v<Tc, float>)
74-
acc += (float)va[i] * (float)vb[i];
7577
else
7678
assert(false && "Unsupported type in matrix_multiply_ref.");
7779
}

sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
1414
big_matrix<T2, K / 2, N * 2> &B) {
1515
size_t NDRangeM = M / TM;
1616
size_t NDRangeN = N / TN;
17-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
18-
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
19-
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
17+
buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
18+
buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
19+
buffer<T1, 2> bufC((T1 *)C.get_data(), range<2>(M, N));
2020

2121
queue q;
2222
size_t sg_size = get_sg_size<imatrix<T1, TM, TN, TK>>(q);
2323
q.submit([&](handler &cgh) {
24-
auto accC = bufC.get_access<access::mode::read_write>(cgh);
25-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
26-
auto accB = bufB.get_access<access::mode::read_write>(cgh);
24+
accessor accA{bufA, cgh};
25+
accessor accB{bufB, cgh};
26+
accessor accC{bufC, cgh};
2727

2828
cgh.parallel_for<imatrix<T1, TM, TN, TK>>(
2929
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
@@ -41,13 +41,11 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
4141
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
4242

4343
sub_group sg = spmd_item.get_sub_group();
44-
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
45-
sub_a;
44+
joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
4645
// For B, we assume B has been already VNNIed.
47-
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
48-
layout::ext_intel_packed>
46+
joint_matrix<sub_group, T2, use::b, TK, TN, layout::ext_intel_packed>
4947
sub_b;
50-
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
48+
joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;
5149

5250
joint_matrix_load(
5351
sg, sub_c,
@@ -122,13 +120,21 @@ int main() {
122120

123121
if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc
124122
test<bfloat16, float, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
123+
// test<bfloat16, bfloat16, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
125124

126125
// This combination is not currently supported for sub group size = 32 in
127126
// IGC
128127
#if (!defined(SG_SZ) || SG_SZ != 32)
129128
test<bfloat16, float, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
129+
// test<bfloat16, bfloat16, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
130130
test<bfloat16, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
131-
test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
131+
// test<bfloat16, bfloat16, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
132+
test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
133+
// test<bfloat16, bfloat16, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
134+
// test<bfloat16, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
135+
// test<bfloat16, bfloat16, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
136+
// test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
137+
// test<bfloat16, bfloat16, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
132138
#endif
133139
break;
134140
}

sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ void matrix_multiply(big_matrix<TResult, M, N> &C, big_matrix<T, M, K> &A,
2424
accessor accB{bufB, cgh};
2525
accessor accC{bufC, cgh};
2626

27-
cgh.parallel_for<mult<T, TM, TN, TK>>(
27+
cgh.parallel_for<mult<TResult, TM, TN, TK>>(
2828
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, sg_size}),
2929
[=](nd_item<2> spmd_item)
3030
#ifdef SG_SZ
@@ -122,6 +122,22 @@ int main() {
122122

123123
if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc
124124
test<float, half, 2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
125+
// test<half, half, 2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
126+
127+
// This combination is not currently supported for sub group size = 32 in
128+
// IGC
129+
#if (!defined(SG_SZ) || SG_SZ != 32)
130+
test<half, float, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
131+
// test<half, half, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
132+
test<half, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
133+
// test<half, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
134+
// test<half, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
135+
// test<half, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
136+
// test<half, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
137+
// test<half, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
138+
// test<half, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
139+
// test<half, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
140+
#endif
125141
break;
126142
}
127143

0 commit comments

Comments
 (0)