Skip to content

Commit 19c68ba

Browse files
committed
[SYCL][Matrix] Fix bfloat16 component type matrix muladd
Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent af8361c commit 19c68ba

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ joint_matrix_mad(
530530
else
531531
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
532532
#else
533-
if constexpr (std::is_same<Ta, uint16_t>::value &&
534-
std::is_same<Tb, uint16_t>::value &&
533+
if constexpr (std::is_same<Ta, sycl::ext::oneapi::bfloat16>::value &&
534+
std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value &&
535535
std::is_same<Tc, float>::value) {
536536
constexpr uint32_t MatrixOperand = static_cast<uint32_t>(
537537
__spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);

0 commit comments

Comments
 (0)