Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,30 @@ inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(
}
}

#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
template<typename Ta, typename Tb, typename Tc>
constexpr uint32_t CalculateMatrixOperand() {
if constexpr (std::is_same<Ta, sycl::ext::oneapi::bfloat16>::value &&
std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value &&
std::is_same<Tc, float>::value)
return static_cast<uint32_t>(
__spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);
if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
return static_cast<uint32_t>(
__spv::MatrixOperands::MatrixASignedComponentsKHR);
if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
return static_cast<uint32_t>(
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
if constexpr (std::is_signed<Ta>::value && std::is_signed<Tb>::value) {
return static_cast<uint32_t>(
__spv::MatrixOperands::MatrixASignedComponentsKHR) +
static_cast<uint32_t>(
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
}
return 0;
}
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

} // namespace detail
} // namespace _V1
} // namespace sycl
34 changes: 4 additions & 30 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,36 +530,10 @@ joint_matrix_mad(
else
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
#else
if constexpr (std::is_same<Ta, uint16_t>::value &&
std::is_same<Tb, uint16_t>::value &&
std::is_same<Tc, float>::value) {
constexpr uint32_t MatrixOperand = static_cast<uint32_t>(
__spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);
D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm,
MatrixOperand);
} else if constexpr (std::is_signed<Ta>::value &&
std::is_unsigned<Tb>::value) {
constexpr uint32_t MatrixOperand = static_cast<uint32_t>(
__spv::MatrixOperands::MatrixASignedComponentsKHR);
D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm,
MatrixOperand);
} else if constexpr (std::is_unsigned<Ta>::value &&
std::is_signed<Tb>::value) {
constexpr uint32_t MatrixOperand = static_cast<uint32_t>(
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm,
MatrixOperand);
} else if constexpr (std::is_signed<Ta>::value && std::is_signed<Tb>::value) {
constexpr uint32_t MatrixOperand =
static_cast<uint32_t>(
__spv::MatrixOperands::MatrixASignedComponentsKHR) +
static_cast<uint32_t>(
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm,
MatrixOperand);
} else {
D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm);
}
constexpr uint32_t MatrixOperand =
sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc>();
D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm,
MatrixOperand);
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
#endif // defined(__NVPTX__)
#else
Expand Down
Loading