Skip to content

Commit 3565a3e

Browse files
committed
[SYCL][Joint Matrix] Add support for Offset joint_matrix_load and joint_matrix_store overloads
1 parent dce9ab3 commit 3565a3e

File tree

5 files changed

+404
-1
lines changed

5 files changed

+404
-1
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,27 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(
311311
T *Ptr, uint32_t NumRows, uint32_t NumCols, unsigned int CacheLevel,
312312
__spv::MatrixLayout Layout, size_t Stride);
313313

314+
template <typename T, typename Tp, std::size_t R, std::size_t C,
315+
__spv::MatrixUse U,
316+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
317+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
318+
extern __DPCPP_SYCL_EXTERNAL
319+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
320+
__spirv_CooperativeMatrixLoadOffsetINTEL(T *Ptr, int32_t RowIndex,
321+
int32_t ColIndex,
322+
__spv::MatrixLayout Layout = L,
323+
std::size_t Stride = 0,
324+
int MemOperand = 0);
325+
326+
template <typename T, typename Tp, std::size_t R, std::size_t C,
327+
__spv::MatrixUse U,
328+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
329+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
330+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreOffsetINTEL(
331+
T *Ptr, int32_t RowIndex, int32_t ColIndex,
332+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
333+
__spv::MatrixLayout Layout = L, std::size_t Stride = 0, int MemOperand = 0);
334+
314335
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
315336
#error \
316337
"SPIR-V built-ins are not available. Please set -fdeclare-spirv-builtins flag."

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,101 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
11201120
}
11211121
// End out-of-bounds API
11221122

1123+
template <
1124+
typename Group, typename T, typename Tp,
1125+
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
1126+
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
1127+
access::address_space Space, access::decorated IsDecorated,
1128+
std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
1129+
Use == sycl::ext::oneapi::experimental::matrix::use::b,
1130+
bool> = true>
1131+
inline __SYCL_ALWAYS_INLINE void
1132+
joint_matrix_store(Group,
1133+
const sycl::ext::oneapi::experimental::matrix::joint_matrix<
1134+
Group, Tp, Use, NumRows, NumCols, Layout> &Src,
1135+
size_t RowIndex, size_t ColIndex,
1136+
multi_ptr<T, Space, IsDecorated> BaseDst, size_t Stride) {
1137+
#if defined(__SYCL_DEVICE_ONLY__)
1138+
static_assert(Space != access::address_space::private_space,
1139+
"Joint Matrix doesn't support store to private memory!");
1140+
#if defined(__NVPTX__)
1141+
std::ignore = Src;
1142+
std::ignore = BaseDst;
1143+
std::ignore = Stride;
1144+
throw exception(
1145+
make_error_code(errc::runtime),
1146+
"This version of the matrix extension is only currently supported on "
1147+
"intel devices");
1148+
#else
1149+
// intel's impl
1150+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
1151+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(BaseDst);
1152+
__spirv_CooperativeMatrixStoreOffsetINTEL<
1153+
DecorT, Tp, NumRows, NumCols,
1154+
sycl::ext::oneapi::experimental::matrix::spv_matrix_use_traits<
1155+
Use>::value,
1156+
sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits<
1157+
Layout>::value>(
1158+
Ptr, RowIndex, ColIndex, Src.spvm,
1159+
sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits<
1160+
Layout>::value,
1161+
Stride);
1162+
#endif // defined(__NVPTX__)
1163+
#else
1164+
std::ignore = Src;
1165+
std::ignore = BaseDst;
1166+
std::ignore = Stride;
1167+
throw exception(make_error_code(errc::runtime),
1168+
"joint matrix is not supported on host.");
1169+
#endif // defined(__SYCL_DEVICE_ONLY__)
1170+
}
1171+
template <
1172+
typename Group, typename T, typename Tp,
1173+
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
1174+
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
1175+
typename PropertyListT,
1176+
std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
1177+
Use == sycl::ext::oneapi::experimental::matrix::use::b,
1178+
bool> = true>
1179+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
1180+
Group,
1181+
const sycl::ext::oneapi::experimental::matrix::joint_matrix<
1182+
Group, Tp, Use, NumRows, NumCols, Layout>
1183+
Src,
1184+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> BaseDst,
1185+
size_t RowIndex, size_t ColIndex, size_t Stride) {
1186+
#if defined(__SYCL_DEVICE_ONLY__)
1187+
#if defined(__NVPTX__)
1188+
std::ignore = Src;
1189+
std::ignore = BaseDst;
1190+
std::ignore = Stride;
1191+
throw exception(
1192+
make_error_code(errc::runtime),
1193+
"This version of the matrix extension is only currently supported on "
1194+
"intel devices");
1195+
#else
1196+
// intel's impl
1197+
T *Ptr = BaseDst.get();
1198+
__spirv_CooperativeMatrixStoreOffsetINTEL<
1199+
T, Tp, NumRows, NumCols,
1200+
sycl::ext::oneapi::experimental::matrix::spv_matrix_use_traits<
1201+
Use>::value,
1202+
sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits<
1203+
Layout>::value>(
1204+
Ptr, Src.spvm,
1205+
sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits<
1206+
Layout>::value,
1207+
RowIndex, ColIndex, Stride);
1208+
#endif // defined(__NVPTX__)
1209+
#else
1210+
std::ignore = Src;
1211+
std::ignore = BaseDst;
1212+
std::ignore = Stride;
1213+
throw exception(make_error_code(errc::runtime),
1214+
"joint matrix is not supported on host.");
1215+
#endif // defined(__SYCL_DEVICE_ONLY__)
1216+
}
1217+
11231218
} // namespace intel::experimental::matrix
11241219

11251220
} // namespace ext

0 commit comments

Comments
 (0)