Skip to content

Commit 1ed8c7c

Browse files
committed
Fix and add back SFINAE
1 parent bab2b3b commit 1ed8c7c

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

ggml/src/ggml-sycl/set_rows.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
#include "set_rows.hpp"
22

3+
namespace utils {
4+
template<typename T>
5+
static constexpr bool is_arithmetic_v() {
6+
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
7+
}
8+
}
39
template<typename TIn, typename TOut>
4-
static inline void convert(const char* src, char* dst) {
10+
static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
11+
convert (const char* src, char* dst) {
512
auto src_val = *reinterpret_cast<const TIn*>(src);
6-
auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut>()[0];
7-
*reinterpret_cast<TOut*>(dst) = dst_val;
13+
auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
14+
*reinterpret_cast<TOut*>(dst) = dst_val;;
815
}
916

1017
template<typename TIn, typename TOut>

0 commit comments

Comments
 (0)