Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,35 +118,39 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
return;
}

template <typename Group, typename T, use Use, size_t M, size_t N,
template <typename Group, typename T0, typename T1, use Use, size_t M, size_t N,
layout Layout, typename F>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jmsrc,
joint_matrix<Group, T, Use, M, N, Layout> &jmdest,
joint_matrix_apply(Group sg, joint_matrix<Group, T0, Use, M, N, Layout> &jm0,
joint_matrix<Group, T1, Use, M, N, Layout> &jm1,
F &&lambda) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
std::ignore = sg;
for (int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
for (int i = 0; i < jm0.matrix_impl.wi_marray.size(); i++) {
lambda(jm0.matrix_impl.wi_marray[i], jm1.matrix_impl.wi_marray[i]);
}
#else // NVPTX
using storage_element_type =
using storage_element_type0 =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jmsrc);
auto wi_data_d = sycl::ext::oneapi::detail::get_wi_data(sg, jmdest);
for (int i = 0; i < wi_data_c.length(); i++) {
storage_element_type elementsrc = wi_data_c[i];
storage_element_type elementdest = wi_data_d[i];
lambda(elementsrc, elementdest);
wi_data_d[i] = elementdest;
T0>::storage_element_type;
using storage_element_type1 =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T1>::storage_element_type;
auto wi_data_0 = sycl::ext::oneapi::detail::get_wi_data(sg, jm0);
auto wi_data_1 = sycl::ext::oneapi::detail::get_wi_data(sg, jm1);
for (int i = 0; i < wi_data_0.length(); i++) {
storage_element_type0 element0 = wi_data_0[i];
storage_element_type1 element1 = wi_data_1[i];
lambda(element0, element1);
wi_data_0[i] = element0;
wi_data_1[i] = element1;
}
#endif
#else
std::ignore = sg;
std::ignore = jmsrc;
std::ignore = jmdest;
std::ignore = jm0;
std::ignore = jm1;
std::ignore = lambda;
throw exception(make_error_code(errc::runtime),
"joint matrix is not supported on host.");
Expand Down
9 changes: 8 additions & 1 deletion sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) {
}
}

/// Applies `op` to every element of a `rows` x `cols` matrix stored
/// contiguously in row-major order, overwriting each element with the value
/// `op` returns for it.
///
/// Elements are visited row by row, from the first column to the last.
template <typename F, typename T>
void matrix_apply(unsigned int rows, unsigned int cols, T *mat, F op) {
  for (unsigned int r = 0; r < rows; ++r) {
    for (unsigned int c = 0; c < cols; ++c) {
      T &cell = mat[r * cols + c];
      cell = op(cell);
    }
  }
}

template <typename T1, typename T2, bool exact = false>
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
for (int i = 0; i < rows; i++) {
Expand All @@ -173,7 +180,7 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
<< ", Epsilon: " << FLOAT_EPSILON << "\n";
return false;
}
} else if constexpr (exact || std::is_same_v<T1, int32_t>) {
} else if constexpr (exact || std::is_integral_v<T1>) {
if (src[i * cols + j] != ref[i * cols + j]) {
std::cout << "Incorrect result in matrix."
<< "i: " << i << ", j: " << j
Expand Down
70 changes: 42 additions & 28 deletions sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,26 @@
//===----------------------------------------------------------------------===//
#include <sycl/usm.hpp>

template <typename Tc, typename Ta, size_t M, size_t N>
bool apply_verify(Tc *C, Tc *D, Ta *A, Ta *Ar) {
for (size_t i = 0; i < M; i++)
for (size_t j = 0; j < N; j++) {
Tc diffc = D[i * N + j] - C[i * N + j] * 2;
Ta diffa = Ar[i * N + j] - (A[i * N + j] + 42);
if constexpr (std::is_same_v<Ta, bfloat16>) {
if (std::fabs(diffc) > FLOAT_EPSILON ||
std::fabs(diffa) > FLOAT_EPSILON || std::isnan(C[i * N + j]) ||
std::isnan(A[i * N + j])) {
return false;
}
} else {
if (std::abs(diffc) > 0 || std::abs(diffa) > 0) {
return false;
}
}
}
return true;
/// Returns `x` multiplied by two, converted back to `T`.
template <typename T>
T mul2(T x) {
  return x * 2;
}

/// Returns `x` increased by five, converted back to `T`.
template <typename T>
T add5(T x) {
  return x + 5;
}

/// Checks the results of a two-matrix apply against a host-side reference.
///
/// `ref` holds the original M x N input values. The function expects:
///   - `D` to equal `mul2` applied element-wise to the original input, and
///   - `C` to equal `add5` applied element-wise to the original input.
///
/// Note: `ref` is modified in place (it ends up holding the `mul2` reference
/// values); a private copy of the original input is used for the `add5`
/// reference.
///
/// @param C   matrix expected to hold the `add5` results.
/// @param D   matrix expected to hold the `mul2` results.
/// @param ref original input values; overwritten by this function.
/// @return true only if both comparisons succeed; false on any mismatch or
///         if the scratch buffer cannot be allocated.
template <typename Tc, size_t M, size_t N>
bool apply_verify(Tc *C, Tc *D, Tc *ref) {
  // Keep a pristine copy of the input, since `ref` is transformed in place
  // below and the second check needs the original values again.
  Tc *refcopy = (Tc *)std::malloc(M * N * sizeof(Tc));
  if (refcopy == nullptr)
    return false;
  memcpy(refcopy, ref, M * N * sizeof(Tc));

  matrix_apply(M, N, ref, mul2<Tc>);
  bool res = matrix_compare(M, N, D, ref);

  matrix_apply(M, N, refcopy, add5<Tc>);
  // `&=` (not `&&`) so the second comparison always runs, even when the
  // first one has already failed.
  res &= matrix_compare(M, N, C, refcopy);

  // Release the scratch copy; previously this buffer was leaked on every call.
  std::free(refcopy);
  return res;
}

template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
size_t N, size_t K, class kernel_name>
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, Tc *Cref, Ta *Aref,
queue q) {
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;

Expand Down Expand Up @@ -70,22 +67,33 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
joint_matrix_load(
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
joint_matrix_apply(sg, sub_c, sub_d,
[](const Tc &x, Tc &y) { y = x * 2; });
joint_matrix_apply(sg, sub_c, sub_d, [](Tc &x, Tc &y) {
y = mul2(x);
x = add5(x);
});
joint_matrix_store(
sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
joint_matrix_store(
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
joint_matrix_load(
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
K);
joint_matrix_apply(sg, sub_a, sub_ar,
[](const Ta &x, Ta &y) { y = x + 42; });
joint_matrix_apply(sg, sub_a, sub_ar, [](Ta &x, Ta &y) {
y = mul2(x);
x = add5(x);
});
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_ar,
pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K);
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
K);
}); // parallel for
}).wait();
return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
return apply_verify<Tc, M, N>(C, D, Cref) &&
apply_verify<Ta, M, N>(A, Ar, Aref);
}

template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
Expand All @@ -96,16 +104,20 @@ bool test() {
static constexpr size_t K = TK * 2;
queue q;

Tc *Cref = malloc_shared<Tc>(M * N, q);
Ta *Aref = malloc_shared<Ta>(M * K, q);
Tc *C = malloc_shared<Tc>(M * N, q);
Tc *D = malloc_shared<Tc>(M * N, q);
Ta *A = malloc_shared<Ta>(M * K, q);
Ta *Ar = malloc_shared<Ta>(M * K, q);

matrix_rand(M, N, (Tc *)C, (Tc)100);
matrix_rand(M, K, (Ta *)A, (Ta)100);
matrix_rand(M, N, (Tc *)Cref, (Tc)100);
matrix_rand(M, K, (Ta *)Aref, (Ta)100);
matrix_copy(M, N, Cref, C);
matrix_copy(M, K, Aref, A);

bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
C, D, A, Ar, q);
C, D, A, Ar, Cref, Aref, q);

if constexpr (std::is_same_v<Ta, bfloat16>)
std::cout << "bfloat16 " << TM << "x" << TN << "x" << TK << ": "
Expand All @@ -117,6 +129,8 @@ bool test() {
free(D, q);
free(A, q);
free(Ar, q);
free(Cref, q);
free(Aref, q);

return res;
}
Expand Down
Loading