88
99#pragma once
1010
11- #include < sycl/aliases.hpp> // for half, cl_char, cl_int
12- #include < sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13- #include < sycl/detail/type_traits.hpp> // for is_floating_point
14-
15- #include < sycl/ext/oneapi/bfloat16.hpp> // bfloat16
16-
17- #include < cstddef>
18- #include < type_traits> // for enable_if_t, is_same
11+ #include < sycl/aliases.hpp>
12+ #include < sycl/detail/generic_type_traits.hpp>
13+ #include < sycl/detail/type_traits.hpp>
14+ #include < sycl/detail/type_traits/vec_marray_traits.hpp>
15+ #include < sycl/ext/oneapi/bfloat16.hpp>
1916
2017namespace sycl {
2118inline namespace _V1 {
@@ -50,13 +47,7 @@ struct UnaryPlus {
5047};
5148
5249struct VecOperators {
53- #ifdef __SYCL_DEVICE_ONLY__
54- static constexpr bool is_host = false ;
55- #else
56- static constexpr bool is_host = true ;
57- #endif
58-
59- template <typename BinOp, typename ... ArgTys>
50+ template <typename OpTy, typename ... ArgTys>
6051 static constexpr auto apply (const ArgTys &...Args) {
6152 using Self = nth_type_t <0 , ArgTys...>;
6253 static_assert (is_vec_v<Self>);
@@ -65,88 +56,99 @@ struct VecOperators {
6556 using element_type = typename Self::element_type;
6657 constexpr int N = Self::size ();
6758 constexpr bool is_logical = check_type_in_v<
68- BinOp , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
59+ OpTy , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
6960 std::greater<void >, std::less_equal<void >, std::greater_equal<void >,
7061 std::logical_and<void >, std::logical_or<void >, std::logical_not<void >>;
7162
7263 using result_t = std::conditional_t <
7364 is_logical, vec<fixed_width_signed<sizeof (element_type)>, N>, Self>;
7465
75- BinOp Op{};
76- if constexpr (is_host || N == 1 ||
77- std::is_same_v<element_type, ext::oneapi::bfloat16>) {
78- result_t res{};
79- for (size_t i = 0 ; i < N; ++i)
80- if constexpr (is_logical)
81- res[i] = Op (Args[i]...) ? -1 : 0 ;
82- else
83- res[i] = Op (Args[i]...);
84- return res;
85- } else {
86- using vector_t = typename Self::vector_t ;
87-
88- auto res = [&](auto ... xs) {
66+ OpTy Op{};
67+ #ifdef __has_extension
68+ #if __has_extension(attribute_ext_vector_type)
69+ // ext_vector_type's bool vectors are mapped onto <N x i1> and have
70+ // different memory layout than sycl::vec<bool ,N> (which has 1 byte per
71+ // element). As such we perform operation on int8_t and then need to
72+ // create bit pattern that can be bit-casted back to the original
73+ // sycl::vec<bool, N>. This is a hack actually, but we've been doing
74+ // that for a long time using sycl::vec::vector_t type.
75+ using vec_elem_ty =
76+ typename detail::map_type<element_type, //
77+ bool , /* ->*/ std::int8_t ,
78+ #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
79+ std::byte, /* ->*/ std::uint8_t ,
80+ #endif
81+ #ifdef __SYCL_DEVICE_ONLY__
82+ half, /* ->*/ _Float16,
83+ #endif
84+ element_type, /* ->*/ element_type>::type;
85+ if constexpr (N != 1 &&
86+ detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
87+ using vec_t = ext_vector<vec_elem_ty, N>;
88+ auto tmp = [&](auto ... xs) {
8989 // Workaround for https://github.com/llvm/llvm-project/issues/119617.
9090 if constexpr (sizeof ...(Args) == 2 ) {
9191 return [&](auto x, auto y) {
92- if constexpr (std::is_same_v<BinOp , std::equal_to<void >>)
92+ if constexpr (std::is_same_v<OpTy , std::equal_to<void >>)
9393 return x == y;
94- else if constexpr (std::is_same_v<BinOp , std::not_equal_to<void >>)
94+ else if constexpr (std::is_same_v<OpTy , std::not_equal_to<void >>)
9595 return x != y;
96- else if constexpr (std::is_same_v<BinOp , std::less<void >>)
96+ else if constexpr (std::is_same_v<OpTy , std::less<void >>)
9797 return x < y;
98- else if constexpr (std::is_same_v<BinOp , std::less_equal<void >>)
98+ else if constexpr (std::is_same_v<OpTy , std::less_equal<void >>)
9999 return x <= y;
100- else if constexpr (std::is_same_v<BinOp , std::greater<void >>)
100+ else if constexpr (std::is_same_v<OpTy , std::greater<void >>)
101101 return x > y;
102- else if constexpr (std::is_same_v<BinOp , std::greater_equal<void >>)
102+ else if constexpr (std::is_same_v<OpTy , std::greater_equal<void >>)
103103 return x >= y;
104104 else
105105 return Op (x, y);
106106 }(xs...);
107107 } else {
108108 return Op (xs...);
109109 }
110- }(bit_cast<vector_t >(Args)...);
111-
110+ }(bit_cast<vec_t >(Args)...);
112111 if constexpr (std::is_same_v<element_type, bool >) {
113- // vec(vector_t) ctor does a simple bit_cast and the way "bool" is
114- // stored is that only one bit matters. vector_t, however, is a char
115- // type and it can have non-zero value with lowest bit unset. E.g.,
116- // consider this:
117- //
118- // auto x = true + true; // int x = 2
119- // bool y = true + true; // bool y = true
120- //
121- // and the vec<bool, N> has to behave in a similar way. As such, current
122- // implementation needs to do some extra processing for operators that
123- // can result in this scenario.
124- //
112+ // Some operations are known to produce the required bit patterns and
113+ // the following post-processing isn't necessary for them:
125114 if constexpr (!is_logical &&
126- !check_type_in_v<BinOp , std::multiplies<void >,
115+ !check_type_in_v<OpTy , std::multiplies<void >,
127116 std::divides<void >, std::bit_or<void >,
128117 std::bit_and<void >, std::bit_xor<void >,
129118 ShiftRight, UnaryPlus>) {
130- // TODO: Not sure why the following doesn't work
131- // (test-e2e/Basic/vector/bool.cpp fails).
132- //
133- // res = (decltype(res))(res != 0);
134- for (size_t i = 0 ; i < N; ++i)
135- res[i] = bit_cast<int8_t >(res[i]) != 0 ;
119+ // Extra cast is needed because:
120+ static_assert (std::is_same_v<int8_t , signed char >);
121+ static_assert (!std::is_same_v<
122+ decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
123+ ext_vector<int8_t , 2 >>);
124+ static_assert (std::is_same_v<
125+ decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
126+ ext_vector<char , 2 >>);
127+
128+ // `... * -1` is needed because ext_vector_type's comparison follows
129+ // OpenCL binary representation for "true" (-1).
130+ // `std::array<bool, N>` is different and LLVM annotates its
131+ // elements with [0, 2) range metadata when loaded, so we need to
132+ // ensure we generate 0/1 only (and not 2/-1/etc.).
133+ #if __clang_major__ >= 20
134+ // Not an integral constant expression prior to clang-20.
135+ static_assert ((ext_vector<int8_t , 2 >{1 , 0 } == 0 )[1 ] == -1 );
136+ #endif
137+
138+ tmp = reinterpret_cast <decltype (tmp)>((tmp != 0 ) * -1 );
136139 }
137140 }
138- // The following is true:
139- //
140- // using char2 = char __attribute__((ext_vector_type(2)));
141- // using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
142- // static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
143- // std::declval<uchar2>()),
144- // char2>);
145- //
146- // so we need some extra casts. Also, static_cast<uchar2>(char2{})
147- // isn't allowed either.
148- return result_t {(typename result_t ::vector_t )res};
141+ return bit_cast<result_t >(tmp);
149142 }
143+ #endif
144+ #endif
145+ result_t res{};
146+ for (size_t i = 0 ; i < N; ++i)
147+ if constexpr (is_logical)
148+ res[i] = Op (Args[i]...) ? -1 : 0 ;
149+ else
150+ res[i] = Op (Args[i]...);
151+ return res;
150152 }
151153};
152154
0 commit comments