|
8 | 8 |
|
9 | 9 | #pragma once |
10 | 10 |
|
11 | | -#include <sycl/access/access.hpp> // for decorated, address_space |
12 | | -#include <sycl/aliases.hpp> // for half, cl_char, cl_double |
13 | | -#include <sycl/detail/helpers.hpp> // for marray |
14 | | -#include <sycl/detail/type_traits.hpp> // for is_gen_based_on_type_s... |
15 | | -#include <sycl/half_type.hpp> // for BIsRepresentationT |
16 | | -#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa... |
17 | | - |
18 | | -#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16 storage type. |
| 11 | +#include <sycl/access/access.hpp> |
| 12 | +#include <sycl/aliases.hpp> |
| 13 | +#include <sycl/bit_cast.hpp> |
| 14 | +#include <sycl/detail/fwd/half.hpp> |
| 15 | +#include <sycl/detail/type_traits.hpp> |
19 | 16 |
|
20 | 17 | #include <cstddef> // for byte |
21 | 18 | #include <cstdint> // for uint8_t |
|
24 | 21 |
|
25 | 22 | namespace sycl { |
26 | 23 | inline namespace _V1 { |
| 24 | +namespace ext::oneapi { |
| 25 | +class bfloat16; |
| 26 | +} |
27 | 27 | namespace detail { |
28 | 28 | template <typename T> |
29 | 29 | using is_byte = typename |
@@ -166,13 +166,16 @@ template <typename T> auto convertToOpenCLType(T &&x) { |
166 | 166 | static_assert(sizeof(OpenCLType) == sizeof(T)); |
167 | 167 | return static_cast<OpenCLType>(x); |
168 | 168 | } else if constexpr (std::is_same_v<no_ref, half>) { |
169 | | - using OpenCLType = sycl::detail::half_impl::BIsRepresentationT; |
| 169 | + // Make it template-param-dependent to compile with incomplete `half`: |
| 170 | + using OpenCLType = |
| 171 | + std::enable_if_t<std::is_same_v<no_ref, half>, |
| 172 | + sycl::detail::half_impl::BIsRepresentationT>; |
170 | 173 | static_assert(sizeof(OpenCLType) == sizeof(T)); |
171 | 174 | return static_cast<OpenCLType>(x); |
172 | 175 | } else if constexpr (std::is_same_v<no_ref, ext::oneapi::bfloat16>) { |
173 | 176 | // On host, don't interpret BF16 as uint16. |
174 | 177 | #ifdef __SYCL_DEVICE_ONLY__ |
175 | | - using OpenCLType = sycl::ext::oneapi::bfloat16::Bfloat16StorageT; |
| 178 | + using OpenCLType = typename no_ref::Bfloat16StorageT; |
176 | 179 | return sycl::bit_cast<OpenCLType>(x); |
177 | 180 | #else |
178 | 181 | return std::forward<T>(x); |
|
0 commit comments