Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
036e9e6
DRAFT - Add different output and accumulator support in spirv. Add ca…
ggojska Mar 18, 2025
2098325
clang-format
ggojska Jun 17, 2025
b7937b1
Change called BF to F conversion function
ggojska Jun 19, 2025
6e81c69
add Unsupported check for SG32 bfloat test
ggojska Jun 19, 2025
44d232f
remove unnecessary comment
ggojska Jun 24, 2025
a908c46
resolve review comments
ggojska Jun 25, 2025
1bbf8a2
Format patch
ggojska Jun 25, 2025
648dad1
Add unsupported mark in no-unsupported-without-info.cpp
ggojska Jun 25, 2025
c04b373
remove unnecessary new lines
ggojska Jun 26, 2025
dcfa304
Added comment for half casting. Removed redundant code
ggojska Jun 26, 2025
59765a6
fix typo and space at end of comment
ggojska Jun 26, 2025
afa625f
Fix errors with Matrix Operands
ggojska Jun 26, 2025
4a8b914
Update sycl/test-e2e/Matrix/SG32/joint_matrix_half_accumulator.cpp
ggojska Jun 27, 2025
96e05ae
Update joint_matrix_bfloat16_accumulator.cpp
ggojska Jun 27, 2025
a713b7e
Update joint_matrix_half_accumulator.cpp
ggojska Jun 27, 2025
7e74578
Update no-unsupported-without-info.cpp
ggojska Jun 27, 2025
1490375
Change unsupported test count
ggojska Jun 30, 2025
fe8952f
Merge branch 'sycl' into commonhpp_reference_calculation_change
ggojska Jul 4, 2025
ed81d7f
Update no-unsupported-without-info.cpp
ggojska Jul 7, 2025
ab52704
Add packedB layout to default bfloat16 test
ggojska Jul 7, 2025
c9e18e5
Add packedB layout to default half test
ggojska Jul 7, 2025
5872cca
Update joint_matrix_half.cpp
ggojska Jul 8, 2025
6569c21
Update joint_matrix_half.cpp
ggojska Jul 8, 2025
431ae7b
Update joint_matrix_half.cpp
ggojska Jul 8, 2025
eef603c
Change Unsupported to unsupported-intended comment in SG32 test files
ggojska Jul 8, 2025
114411e
format fix
ggojska Jul 8, 2025
b1d7638
Fix unsupported labels
ggojska Jul 8, 2025
f9425a5
Change unsupported test count
ggojska Jul 8, 2025
a82ba68
Add unsupported cpu comment
ggojska Jul 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);

template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
template <typename TA, typename TB, typename TC, typename TD, std::size_t M,
std::size_t K, std::size_t N, __spv::MatrixUse UA,
__spv::MatrixUse UB, __spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
__spv::__spirv_CooperativeMatrixKHR<TD, S, M, N, UC> *
__spirv_CooperativeMatrixMulAddKHR(
__spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
__spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
Expand Down
34 changes: 34 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#endif // __SYCL_DEVICE_ONLY__
}

// Implicit widening conversion of a bfloat16 matrix element to float.
// Fetches the element out of the joint matrix via __spirv_AccessChain and
// converts it with the native BF16->FP32 SPIR-V intrinsic.
operator float() {
#ifdef __SYCL_DEVICE_ONLY__
sycl::ext::oneapi::bfloat16 *ExtractP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
// Reinterpret the bfloat16 value as its raw 16-bit pattern, which is what
// __spirv_ConvertBF16ToFINTEL expects.
// NOTE(review): reading a union member other than the one last written is
// type punning (UB in ISO C++); sycl::bit_cast would express the same
// intent without it -- confirm whether the device compiler guarantees this.
union {
uint16_t intStorage;
sycl::ext::oneapi::bfloat16 floatValue;
};
floatValue = *ExtractP;
return __spirv_ConvertBF16ToFINTEL(intStorage);
#else
// Host fallback: joint_matrix element access is device-only.
throw exception(make_error_code(errc::runtime),
"joint matrix is not supported on host.");
#endif // __SYCL_DEVICE_ONLY__
}

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
sycl::ext::oneapi::bfloat16 *ExtractP =
Expand Down Expand Up @@ -295,6 +314,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#endif // __SYCL_DEVICE_ONLY__
}

// Assigns a float value to this bfloat16 matrix element.
// NOTE(review): the element is accessed through a float-typed
// __spirv_AccessChain even though the matrix element type is bfloat16 (the
// bfloat16 overload above uses a bfloat16-typed chain); presumably the
// backend performs the FP32->BF16 narrowing here -- confirm against the
// SPV_INTEL_joint_matrix lowering.
wi_element &operator=(const float &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
float *InsertP =
__spirv_AccessChain<float, float, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
*InsertP = rhs;
return *this;
#else
// Host fallback: joint_matrix element access is device-only.
(void)rhs;
throw exception(make_error_code(errc::runtime),
"joint matrix is not supported on host.");
#endif // __SYCL_DEVICE_ONLY__
}

wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
NumCols, Use, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
Expand Down
30 changes: 15 additions & 15 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,26 @@ extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv(
}
}

template<typename Ta, typename Tb, typename Tc>
template <typename Ta, typename Tb, typename Tc, typename Td>
constexpr uint32_t CalculateMatrixOperand() {
uint32_t returnValue = 0x00;
if constexpr (std::is_same<Ta, sycl::ext::oneapi::bfloat16>::value &&
std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value &&
std::is_same<Tc, float>::value)
return static_cast<uint32_t>(
std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value)
returnValue += static_cast<uint32_t>(
__spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);
if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
return static_cast<uint32_t>(
if constexpr (std::is_same<Tc, sycl::ext::oneapi::bfloat16>::value)
returnValue += static_cast<uint32_t>(
__spv::MatrixOperands::MatrixCBFloat16ComponentsINTEL);
if constexpr (std::is_same<Td, sycl::ext::oneapi::bfloat16>::value)
returnValue += static_cast<uint32_t>(
__spv::MatrixOperands::MatrixResultBFloat16ComponentsINTEL);
if constexpr (std::is_signed<Ta>::value)
returnValue += static_cast<uint32_t>(
__spv::MatrixOperands::MatrixASignedComponentsKHR);
if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
return static_cast<uint32_t>(
if constexpr (std::is_signed<Tb>::value)
returnValue += static_cast<uint32_t>(
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
if constexpr (std::is_signed<Ta>::value && std::is_signed<Tb>::value) {
return static_cast<uint32_t>(
__spv::MatrixOperands::MatrixASignedComponentsKHR) +
static_cast<uint32_t>(
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
}
return 0;
return returnValue;
}

} // namespace detail
Expand Down
24 changes: 18 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,7 @@ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
sycl::detail::convertTypeToMatrixTypeString<Tc>(),
sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
#endif // defined(__SYCL_DEVICE_ONLY__)
inline __SYCL_ALWAYS_INLINE void
joint_matrix_mad(
inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
Group,
joint_matrix<Group, Td, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
Expand Down Expand Up @@ -462,9 +461,9 @@ joint_matrix_mad(
}
#else
constexpr uint32_t MatrixOperand =
sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc>();
D.spvm =
__spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, MatrixOperand);
sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc, Td>();
D.spvm = __spirv_CooperativeMatrixMulAddKHR<Ta, Tb, Tc, Td>(
A.spvm, B.spvm, C.spvm, MatrixOperand);
#endif // defined(__NVPTX__)
#else
std::ignore = A;
Expand All @@ -489,10 +488,23 @@ void joint_matrix_copy(
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T2>::storage_element_type;
using src_storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T1>::storage_element_type;

auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
for (int i = 0; i < wi_data_c.length(); i++) {
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
if constexpr (std::is_same_v<T1, sycl::half>) {
// Special case for SRC type sycl:half since we can't
// cast directly from wi_element(typed half) to other type.
// first cast is from wi_element to half (T1).
// second cast is from half to dst type (T2).
wi_data_dst[i] = static_cast<storage_element_type>(
static_cast<src_storage_element_type>(wi_data_c[i]));
} else {
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
}
}
#endif // defined(__NVPTX__)
#else
Expand Down
17 changes: 12 additions & 5 deletions sycl/test-e2e/Matrix/Inputs/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
for (unsigned int n = 0; n < N; n++) {
int c_ind = transpose_c ? (n * M + m) : m * N + n;
Tc acc = *(C + c_ind);

float tmp = 0.f;
for (unsigned int k = 0; k < K; k++) {
int a_ind = colmajor_a ? (k * M + m) : m * K + k;
int b_ind = colmajor_b ? (n * K + k) : k * N + n;
Expand All @@ -80,6 +80,9 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
acc += make_fp32(va[i]) * make_fp32(vb[i]);
else if constexpr (std::is_same_v<Ta, sycl::half>)
acc += (float)va[i] * (float)vb[i];
else if constexpr (std::is_same_v<Ta, bfloat16> &&
std::is_same_v<Tc, bfloat16>)
tmp += (float)va[i] * (float)vb[i];
else if constexpr (std::is_same_v<Ta, float> &&
std::is_same_v<Tc, float> ||
std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
Expand All @@ -92,6 +95,9 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
assert(false && "Unsupported type in matrix_multiply_ref.");
}
}
if constexpr (std::is_same_v<Ta, bfloat16> &&
std::is_same_v<Tc, bfloat16>)
acc += (bfloat16)tmp;

if constexpr (!std::is_same_v<F, std::nullptr_t>) {
lambda(acc);
Expand Down Expand Up @@ -182,10 +188,11 @@ template <typename T1, typename T2, bool exact = false>
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
if constexpr (!exact && (std::is_same_v<T1, float> ||
std::is_same_v<T1, bfloat16> ||
(std::is_same_v<T1, double> &&
std::is_same_v<T2, double>))) {
if constexpr (!exact &&
(std::is_same_v<T1, float> ||
std::is_same_v<T1, bfloat16> || std::is_same_v<T1, half> ||
(std::is_same_v<T1, double> &&
std::is_same_v<T2, double>))) {
float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
if (diff > FLOAT_EPSILON || std::isnan(src[i * cols + j])) {
std::cerr << "Incorrect result in matrix. "
Expand Down
138 changes: 138 additions & 0 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_16bit_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
//===---joint_matrix_16bit_impl.hpp - DPC++ joint_matrix----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Kernel name tag, parameterized on the full test configuration.
// Forward-declared so the host side can query the sub-group size for the
// named kernel (get_sg_size) before the kernel itself is defined in
// matrix_multiply's parallel_for.
template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
size_t TK, layout B_layout>
class imatrix;

// Device-side GEMM using the joint_matrix API: computes D = A * B + C where
// A/B share element type Tab, C uses TAcc and D uses TResult.  The matrix is
// tiled into TM x TN sub-results; B is stored pre-packed ("VNNI") by factor
// VF, i.e. laid out as (K / VF) x (N * VF).
template <typename Tab, typename TAcc, typename TResult, size_t M, size_t N,
size_t K, size_t TM, size_t TN, size_t TK, layout B_layout, size_t VF>
void matrix_multiply(big_matrix<TResult, M, N> &D, big_matrix<TAcc, M, N> &C,
big_matrix<Tab, M, K> &A,
big_matrix<Tab, K / VF, N * VF> &B) {
// One work-group row/column per TM x TN output tile.
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;
buffer<Tab, 2> bufA(A.get_data(), range<2>(M, K));
buffer<Tab, 2> bufB(B.get_data(), range<2>(K, N));
buffer<TAcc, 2> bufC((TAcc *)C.get_data(), range<2>(M, N));
buffer<TResult, 2> bufD((TResult *)D.get_data(), range<2>(M, N));
queue q;
// Sub-group size the runtime picked for this (named) kernel; used to map
// work-items to output tiles below.
size_t sg_size =
get_sg_size<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(q);

q.submit([&](handler &cgh) {
accessor accA{bufA, cgh};
accessor accB{bufB, cgh};
accessor accC{bufC, cgh};
accessor accD{bufD, cgh};

cgh.parallel_for<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
[[sycl::reqd_sub_group_size(SG_SZ)]]
#endif
{
// The submatrix API has to be accessed by all the work-items in a
// sub-group; these functions are called once per sub-group, with no
// code divergence between the work-items.
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
// Top-left coordinates of this sub-group's output tile.
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, Tab, use::a, TM, TK, layout::row_major>
sub_a;
joint_matrix<sub_group, Tab, use::b, TK, TN, B_layout> sub_b;
joint_matrix<sub_group, TAcc, use::accumulator, TM, TN> sub_c;
joint_matrix<sub_group, TResult, use::accumulator, TM, TN> sub_d;

// Load the initial accumulator tile from C once, before the K loop.
joint_matrix_load(
sg, sub_c,
accC.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);

for (int k = 0; k < K / TK; k += 1) {
joint_matrix_load(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * K + k * TK,
K);
// B is packed by VF, hence the scaled row pitch and column offset.
joint_matrix_load(
sg, sub_b,
accB.template get_multi_ptr<access::decorated::no>() +
(k * TK / VF) * (N * VF) + sg_starty / sg_size * TN * VF,
N * VF);

// sub_d = sub_a * sub_b + sub_c, then feed the TResult-typed result
// back into the TAcc-typed accumulator for the next K iteration.
joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
joint_matrix_copy(sg, sub_d, sub_c);
}

// Write the final tile to D.
joint_matrix_store(
sg, sub_d,
accD.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
}); // parallel for
}).wait();
}

template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
size_t TK, layout B_layout, size_t VF>
void test() {
std::cout << "Testing: " << TM << " x " << TN << " x " << TK
<< " [TM x TN x TK]" << std::endl;

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
Tab A[MATRIX_M][MATRIX_K];
Tab B[MATRIX_K / VF][MATRIX_N * VF];
TAcc C[MATRIX_M][MATRIX_N];
TResult D[MATRIX_M][MATRIX_N];
TResult DRef[MATRIX_M][MATRIX_N];

matrix_rand<Tab>(MATRIX_M, MATRIX_K, (Tab *)A, Tab(1));
matrix_rand<Tab>(MATRIX_K / VF, MATRIX_N * VF, (Tab *)B, Tab(1));

matrix_fill(MATRIX_M, MATRIX_N, (TAcc *)C, TAcc(1));
matrix_fill(MATRIX_M, MATRIX_N, (TResult *)D, TResult(1));
matrix_fill(MATRIX_M, MATRIX_N, (TResult *)DRef, TResult(1));

big_matrix<TAcc, MATRIX_M, MATRIX_N> MC((TAcc *)&C);
big_matrix<TResult, MATRIX_M, MATRIX_N> MD((TResult *)&D);
big_matrix<Tab, MATRIX_M, MATRIX_K> MA((Tab *)&A);
big_matrix<Tab, MATRIX_K / VF, MATRIX_N * VF> MB((Tab *)&B);

matrix_multiply<Tab, TAcc, TResult, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
B_layout, VF>(MD, MC, MA, MB);
matrix_multiply_ref<Tab, Tab, TResult, VF>(
(Tab *)A, (Tab *)B, (TResult *)DRef, MATRIX_M, MATRIX_N, MATRIX_K / VF);
assert(matrix_compare(MATRIX_M, MATRIX_N, (TResult *)D, (TResult *)DRef));
}

// Runs one shape against all four (accumulator, result) type pairings of
// the low- and high-precision element types.
template <typename TNarrow, typename TWide, size_t TM, size_t TN, size_t TK,
layout B_layout, size_t VF>
void test_combo() {
test<TNarrow, TNarrow, TWide, TM, TN, TK, B_layout, VF>(); // acc narrow, result wide
test<TNarrow, TWide, TNarrow, TM, TN, TK, B_layout, VF>(); // acc wide, result narrow
test<TNarrow, TNarrow, TNarrow, TM, TN, TK, B_layout, VF>(); // both narrow
test<TNarrow, TWide, TWide, TM, TN, TK, B_layout, VF>(); // both wide
}

// Sweeps all supported TM x TN x TK tile shapes for one type pair and one
// B-matrix layout, running every accumulator/result combination for each.
template <typename TNarrow, typename TWide, layout B_layout, size_t VF>
void test_all() {
// Arguments after the type pair are TM, TN, TK.
test_combo<TNarrow, TWide, 8, 16, 16, B_layout, VF>();
test_combo<TNarrow, TWide, 16, 16, 16, B_layout, VF>();
test_combo<TNarrow, TWide, 1, 64, 16, B_layout, VF>();
test_combo<TNarrow, TWide, 1, 64, 32, B_layout, VF>();
test_combo<TNarrow, TWide, 32, 64, 16, B_layout, VF>();
test_combo<TNarrow, TWide, 32, 64, 32, B_layout, VF>();
}
Loading