@@ -63,8 +63,7 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
6363 if constexpr (std::is_same_v<Ta, bfloat16> &&
6464 std::is_same_v<Tc, float >)
6565 acc += make_fp32 (va[i]) * make_fp32 (vb[i]);
66- else if constexpr (std::is_same_v<Ta, sycl::half> &&
67- std::is_same_v<Tc, float >)
66+ else if constexpr (std::is_same_v<Ta, sycl::half>)
6867 acc += (float )va[i] * (float )vb[i];
6968 else if constexpr (std::is_same_v<Ta, float > &&
7069 std::is_same_v<Tc, float > ||
@@ -135,7 +134,8 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
135134
136135 for (unsigned int i = 0 ; i < rows; i++) {
137136 for (unsigned int j = 0 ; j < cols; j++) {
138- if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float > ||
137+ if constexpr (std::is_same_v<T, sycl::half> ||
138+ std::is_same_v<T, bfloat16> || std::is_same_v<T, float > ||
139139 std::is_same_v<T, double >) {
140140 src[i * cols + j] = T (fdistr (dev));
141141 } else if constexpr (std::is_integral_v<T>) {
0 commit comments