88
99#pragma once
1010
11- #include < sycl/aliases.hpp>
12- #include < sycl/detail/generic_type_traits.hpp>
13- #include < sycl/detail/type_traits.hpp>
14- #include < sycl/detail/type_traits/vec_marray_traits.hpp>
15- #include < sycl/ext/oneapi/bfloat16.hpp>
11+ #include < sycl/aliases.hpp> // for half, cl_char, cl_int
12+ #include < sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13+ #include < sycl/detail/type_traits.hpp> // for is_floating_point
14+
15+ #include < sycl/ext/oneapi/bfloat16.hpp> // bfloat16
16+
17+ #include < cstddef>
18+ #include < type_traits> // for enable_if_t, is_same
1619
1720namespace sycl {
1821inline namespace _V1 {
@@ -47,7 +50,13 @@ struct UnaryPlus {
4750};
4851
// VecOperators bundles the static `apply` helper that evaluates an operator
// functor over vec operands.
// NOTE(review): this text is a rendered diff -- lines whose number is followed
// by `-` belong to the removed implementation, lines followed by `+` to the
// added one, and lines carrying two fused numbers are unchanged context.
4952struct VecOperators {
50- template <typename OpTy, typename ... ArgTys>
  // Compile-time flag: false when __SYCL_DEVICE_ONLY__ is defined, true
  // otherwise. It is used below to select the element-wise code path.
53+ #ifdef __SYCL_DEVICE_ONLY__
54+ static constexpr bool is_host = false ;
55+ #else
56+ static constexpr bool is_host = true ;
57+ #endif
58+
  // Applies the functor `BinOp` (default-constructed as `Op`) to `Args`.
  //  - The first type in `ArgTys` is taken as `Self` and must satisfy
  //    is_vec_v.
  //  - Despite the template parameter's name, unary functors are also
  //    recognised by the checks below (std::logical_not, UnaryPlus).
  //  - Result type: for the comparison/logical functors listed in
  //    `is_logical` it is a vec of fixed_width_signed<sizeof(element_type)>
  //    with the same element count N; for every other functor it is `Self`.
59+ template <typename BinOp, typename ... ArgTys>
5160 static constexpr auto apply (const ArgTys &...Args) {
5261 using Self = nth_type_t <0 , ArgTys...>;
5362 static_assert (is_vec_v<Self>);
@@ -56,96 +65,88 @@ struct VecOperators {
5665 using element_type = typename Self::element_type;
5766 constexpr int N = Self::size ();
5867 constexpr bool is_logical = check_type_in_v<
59- OpTy , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
68+ BinOp , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
6069 std::greater<void >, std::less_equal<void >, std::greater_equal<void >,
6170 std::logical_and<void >, std::logical_or<void >, std::logical_not<void >>;
6271
6372 using result_t = std::conditional_t <
6473 is_logical, vec<fixed_width_signed<sizeof (element_type)>, N>, Self>;
6574
66- OpTy Op{};
67- #ifdef __has_extension
68- #if __has_extension(attribute_ext_vector_type)
69- // ext_vector_type's bool vectors are mapped onto <N x i1> and have
70- // different memory layout than sycl::vec<bool ,N> (which has 1 byte per
71- // element). As such we perform operation on int8_t and then need to
72- // create bit pattern that can be bit-casted back to the original
73- // sycl::vec<bool, N>. This is a hack actually, but we've been doing
74- // that for a long time using sycl::vec::vector_t type.
75- using vec_elem_ty =
76- typename detail::map_type<element_type, //
77- bool , /* ->*/ std::int8_t ,
78- #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
79- std::byte, /* ->*/ std::uint8_t ,
80- #endif
81- #ifdef __SYCL_DEVICE_ONLY__
82- half, /* ->*/ _Float16,
83- #endif
84- element_type, /* ->*/ element_type>::type;
85- if constexpr (N != 1 &&
86- detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
87- using vec_t = ext_vector<vec_elem_ty, N>;
88- auto tmp = [&](auto ... xs) {
75+ BinOp Op{};
  // Element-wise path: selected when is_host is true, when the vector has a
  // single element, or when the element type is ext::oneapi::bfloat16.
  // Logical functors store -1 for a true result and 0 for a false one; all
  // other functors store Op's result directly.
76+ if constexpr (is_host || N == 1 ||
77+ std::is_same_v<element_type, ext::oneapi::bfloat16>) {
78+ result_t res{};
79+ for (size_t i = 0 ; i < N; ++i)
80+ if constexpr (is_logical)
81+ res[i] = Op (Args[i]...) ? -1 : 0 ;
82+ else
83+ res[i] = Op (Args[i]...);
84+ return res;
85+ } else {
  // Whole-vector path: every operand is bit_cast to Self::vector_t and the
  // operation is evaluated on those values inside the lambda below.
86+ using vector_t = typename Self::vector_t ;
87+
88+ auto res = [&](auto ... xs) {
8989 // Workaround for https://github.com/llvm/llvm-project/issues/119617.
  // For two operands the six comparison functors are spelled out as the
  // corresponding built-in operators; anything else goes through Op.
9090 if constexpr (sizeof ...(Args) == 2 ) {
9191 return [&](auto x, auto y) {
92- if constexpr (std::is_same_v<OpTy , std::equal_to<void >>)
92+ if constexpr (std::is_same_v<BinOp , std::equal_to<void >>)
9393 return x == y;
94- else if constexpr (std::is_same_v<OpTy , std::not_equal_to<void >>)
94+ else if constexpr (std::is_same_v<BinOp , std::not_equal_to<void >>)
9595 return x != y;
96- else if constexpr (std::is_same_v<OpTy , std::less<void >>)
96+ else if constexpr (std::is_same_v<BinOp , std::less<void >>)
9797 return x < y;
98- else if constexpr (std::is_same_v<OpTy , std::less_equal<void >>)
98+ else if constexpr (std::is_same_v<BinOp , std::less_equal<void >>)
9999 return x <= y;
100- else if constexpr (std::is_same_v<OpTy , std::greater<void >>)
100+ else if constexpr (std::is_same_v<BinOp , std::greater<void >>)
101101 return x > y;
102- else if constexpr (std::is_same_v<OpTy , std::greater_equal<void >>)
102+ else if constexpr (std::is_same_v<BinOp , std::greater_equal<void >>)
103103 return x >= y;
104104 else
105105 return Op (x, y);
106106 }(xs...);
107107 } else {
108108 return Op (xs...);
109109 }
110- }(bit_cast<vec_t >(Args)...);
110+ }(bit_cast<vector_t >(Args)...);
111+
111112 if constexpr (std::is_same_v<element_type, bool >) {
112- // Some operations are known to produce the required bit patterns and
113- // the following post-processing isn't necessary for them:
113+ // vec(vector_t) ctor does a simple bit_cast and the way "bool" is
114+ // stored is that only one bit matters. vector_t, however, is a char
115+ // type and it can have non-zero value with lowest bit unset. E.g.,
116+ // consider this:
117+ //
118+ // auto x = true + true; // int x = 2
119+ // bool y = true + true; // bool y = true
120+ //
121+ // and the vec<bool, N> has to behave in a similar way. As such, current
122+ // implementation needs to do some extra processing for operators that
123+ // can result in this scenario.
124+ //
  // The functors listed in the check below (and all logical ones) are
  // exempt from the post-processing.
114125 if constexpr (!is_logical &&
115- !check_type_in_v<OpTy , std::multiplies<void >,
126+ !check_type_in_v<BinOp , std::multiplies<void >,
116127 std::divides<void >, std::bit_or<void >,
117128 std::bit_and<void >, std::bit_xor<void >,
118129 ShiftRight, UnaryPlus>) {
119- // Extra cast is needed because:
120- static_assert (std::is_same_v<int8_t , signed char >);
121- static_assert (!std::is_same_v<
122- decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
123- ext_vector<int8_t , 2 >>);
124- static_assert (std::is_same_v<
125- decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
126- ext_vector<char , 2 >>);
127-
128- // `... * -1` is needed because ext_vector_type's comparison follows
129- // OpenCL binary representation for "true" (-1).
130- // `std::array<bool, N>` is different and LLVM annotates its
131- // elements with [0, 2) range metadata when loaded, so we need to
132- // ensure we generate 0/1 only (and not 2/-1/etc.).
133- static_assert ((ext_vector<int8_t , 2 >{1 , 0 } == 0 )[1 ] == -1 );
134-
135- tmp = reinterpret_cast <decltype (tmp)>((tmp != 0 ) * -1 );
130+ // TODO: Not sure why the following doesn't work
131+ // (test-e2e/Basic/vector/bool.cpp fails).
132+ //
133+ // res = (decltype(res))(res != 0);
  // Per-element normalisation: each element is replaced by the result
  // of comparing its int8_t bit pattern against zero.
134+ for (size_t i = 0 ; i < N; ++i)
135+ res[i] = bit_cast<int8_t >(res[i]) != 0 ;
136136 }
137137 }
138- return bit_cast<result_t >(tmp);
138+ // The following is true:
139+ //
140+ // using char2 = char __attribute__((ext_vector_type(2)));
141+ // using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
142+ // static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
143+ // std::declval<uchar2>()),
144+ // char2>);
145+ //
146+ // so we need some extra casts. Also, static_cast<uchar2>(char2{})
147+ // isn't allowed either.
148+ return result_t {(typename result_t ::vector_t )res};
139149 }
140- #endif
141- #endif
142- result_t res{};
143- for (size_t i = 0 ; i < N; ++i)
144- if constexpr (is_logical)
145- res[i] = Op (Args[i]...) ? -1 : 0 ;
146- else
147- res[i] = Op (Args[i]...);
148- return res;
149150 }
150151};
151152
0 commit comments