diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index ec5f08a6c809f..9d2759bdd3ad5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -82,6 +82,30 @@ inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv( } } +#ifdef __SPIRV_USE_COOPERATIVE_MATRIX +template +constexpr uint32_t CalculateMatrixOperand() { + if constexpr (std::is_same::value && + std::is_same::value && + std::is_same::value) + return static_cast( + __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL); + if constexpr (std::is_signed::value && std::is_unsigned::value) + return static_cast( + __spv::MatrixOperands::MatrixASignedComponentsKHR); + if constexpr (std::is_unsigned::value && std::is_signed::value) + return static_cast( + __spv::MatrixOperands::MatrixBSignedComponentsKHR); + if constexpr (std::is_signed::value && std::is_signed::value) { + return static_cast( + __spv::MatrixOperands::MatrixASignedComponentsKHR) + + static_cast( + __spv::MatrixOperands::MatrixBSignedComponentsKHR); + } + return 0; +} +#endif // __SPIRV_USE_COOPERATIVE_MATRIX + } // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index d3d57f24c56e6..c8d2918b6b105 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -530,36 +530,10 @@ joint_matrix_mad( else D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); #else - if constexpr (std::is_same::value && - std::is_same::value && - std::is_same::value) { - constexpr uint32_t MatrixOperand = static_cast( - __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else if constexpr (std::is_signed::value && - std::is_unsigned::value) { - constexpr uint32_t MatrixOperand = static_cast( - __spv::MatrixOperands::MatrixASignedComponentsKHR); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else if constexpr (std::is_unsigned::value && - std::is_signed::value) { - constexpr uint32_t MatrixOperand = static_cast( - __spv::MatrixOperands::MatrixBSignedComponentsKHR); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else if constexpr (std::is_signed::value && std::is_signed::value) { - constexpr uint32_t MatrixOperand = - static_cast( - __spv::MatrixOperands::MatrixASignedComponentsKHR) + - static_cast( - __spv::MatrixOperands::MatrixBSignedComponentsKHR); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else { - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm); - } + constexpr uint32_t MatrixOperand = + sycl::detail::CalculateMatrixOperand(); + D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, + MatrixOperand); #endif // __SPIRV_USE_COOPERATIVE_MATRIX #endif // defined(__NVPTX__) #else