Skip to content
Merged
102 changes: 102 additions & 0 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ struct IncDec {};

template <class T> static constexpr bool not_fp = !is_vgenfloat_v<T>;

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
// `not_byte<T>` is true for every type except std::byte. It is spelled out by
// hand instead of going through `is_byte_v` so that this header does not pick
// up unnecessary dependencies on the `half`/`bfloat16` headers.
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
template <class T>
static constexpr bool not_byte = !std::is_same<T, std::byte>::value;
#else
// _HAS_STD_BYTE is defined as 0 here, so std::byte is not named and every
// type passes the check.
template <class T> static constexpr bool not_byte = true;
#endif
#endif

// To provide information about operators availability depending on vec/swizzle
// element type.
template <typename Op, typename T>
Expand All @@ -80,6 +92,7 @@ inline constexpr bool is_op_available_for_type<OpAssign<Op>, T> =
inline constexpr bool is_op_available_for_type<OP, T> = COND;

// clang-format off
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
__SYCL_OP_AVAILABILITY(std::plus<void> , true)
__SYCL_OP_AVAILABILITY(std::minus<void> , true)
__SYCL_OP_AVAILABILITY(std::multiplies<void> , true)
Expand Down Expand Up @@ -110,6 +123,38 @@ __SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(UnaryPlus , true)

__SYCL_OP_AVAILABILITY(IncDec , true)
#else
// Operator availability when __SYCL_USE_LIBSYCL8_VEC_IMPL is off.
//
// std::byte is an enumeration, not an arithmetic type: the standard library
// gives it only the bitwise operators (and shifts by an integer count) plus
// the built-in comparisons. Every arithmetic, logical, shift and
// increment/decrement row is therefore additionally gated on not_byte<T>.
__SYCL_OP_AVAILABILITY(std::plus<void>          , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::minus<void>         , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::multiplies<void>    , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::divides<void>       , not_byte<T>)
// `%` exists neither for floating-point types nor for std::byte, so both
// conditions are required (not_fp<T> alone would advertise byte % byte).
__SYCL_OP_AVAILABILITY(std::modulus<void>       , not_byte<T> && not_fp<T>)

// Bitwise operators: valid for std::byte, invalid for floating point.
__SYCL_OP_AVAILABILITY(std::bit_and<void>       , not_fp<T>)
__SYCL_OP_AVAILABILITY(std::bit_or<void>        , not_fp<T>)
__SYCL_OP_AVAILABILITY(std::bit_xor<void>       , not_fp<T>)

// Comparisons are available for every element type.
__SYCL_OP_AVAILABILITY(std::equal_to<void>      , true)
__SYCL_OP_AVAILABILITY(std::not_equal_to<void>  , true)
__SYCL_OP_AVAILABILITY(std::less<void>          , true)
__SYCL_OP_AVAILABILITY(std::greater<void>       , true)
__SYCL_OP_AVAILABILITY(std::less_equal<void>    , true)
__SYCL_OP_AVAILABILITY(std::greater_equal<void> , true)

__SYCL_OP_AVAILABILITY(std::logical_and<void>   , not_byte<T> && not_fp<T>)
__SYCL_OP_AVAILABILITY(std::logical_or<void>    , not_byte<T> && not_fp<T>)

__SYCL_OP_AVAILABILITY(ShiftLeft                , not_byte<T> && not_fp<T>)
__SYCL_OP_AVAILABILITY(ShiftRight               , not_byte<T> && not_fp<T>)

// Unary
__SYCL_OP_AVAILABILITY(std::negate<void>        , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::logical_not<void>   , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::bit_not<void>       , not_fp<T>)
__SYCL_OP_AVAILABILITY(UnaryPlus                , not_byte<T>)

__SYCL_OP_AVAILABILITY(IncDec                   , not_byte<T>)
#endif
// clang-format on

#undef __SYCL_OP_AVAILABILITY
Expand Down Expand Up @@ -188,6 +233,12 @@ template <typename Self> struct VecOperators {
using element_type = typename from_incomplete<Self>::element_type;
static constexpr int N = from_incomplete<Self>::size();

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename T>
static constexpr bool is_compatible_scalar =
std::is_convertible_v<T, typename from_incomplete<Self>::element_type>;
#endif

template <typename Op>
using result_t = std::conditional_t<
is_logical<Op>, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
Expand Down Expand Up @@ -293,6 +344,7 @@ template <typename Self> struct VecOperators {
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, IncDec>>>
: public IncDecImpl<Self> {};

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
Expand Down Expand Up @@ -341,13 +393,60 @@ template <typename Self> struct VecOperators {
friend auto operator OPERATOR(const Self &v) { return apply<OP>(v); } \
};

#else

// Defines the OpMixin specialization that supplies binary operator OPERATOR
// (identified by the functor/tag type OP) as hidden friends of Self:
//   * Self   OPERATOR Self
//   * Self   OPERATOR scalar
//   * scalar OPERATOR Self
// The scalar overloads take part in overload resolution only when
// is_compatible_scalar<T> holds. The scalar is converted to element_type
// explicitly before Self is built from it: static_cast<T> on a `const T &`
// would not convert anything, and brace-initializing Self from a T of a
// different type can be rejected as a narrowing conversion.
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
  template <typename Op> \
  struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
    friend result_t<OP> operator OPERATOR(const Self & lhs, \
                                          const Self & rhs) { \
      return VecOperators::apply<OP>(lhs, rhs); \
    } \
    template <typename T> \
    friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
    operator OPERATOR(const Self & lhs, const T & rhs) { \
      return VecOperators::apply<OP>( \
          lhs, Self{static_cast<element_type>(rhs)}); \
    } \
    template <typename T> \
    friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
    operator OPERATOR(const T & lhs, const Self & rhs) { \
      return VecOperators::apply<OP>( \
          Self{static_cast<element_type>(lhs)}, rhs); \
    } \
  };

// Defines the OpMixin specialization selected for OpAssign<OP>; it supplies
// the compound-assignment operator OPERATOR as hidden friends of Self.
// Both overloads evaluate `OP{}(lhs, rhs)`, assign the result to lhs and
// return lhs by reference. The scalar overload takes part in overload
// resolution only when is_compatible_scalar<T> holds.
#define __SYCL_VEC_OPASSIGN_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OpAssign<OP>>>> { \
friend Self &operator OPERATOR(Self & lhs, const Self & rhs) { \
lhs = OP{}(lhs, rhs); \
return lhs; \
} \
template <typename T> \
friend std::enable_if_t<is_compatible_scalar<T>, Self &> \
operator OPERATOR(Self & lhs, const T & rhs) { \
lhs = OP{}(lhs, rhs); \
return lhs; \
} \
};

// Defines the OpMixin specialization that supplies unary operator OPERATOR
// (identified by the functor/tag type OP) as a hidden friend of Self. The
// operator forwards to apply<OP>(v) and returns result_t<OP>.
#define __SYCL_VEC_UOP_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
friend result_t<OP> operator OPERATOR(const Self & v) { \
return apply<OP>(v); \
} \
};

#endif

__SYCL_INSTANTIATE_OPERATORS(__SYCL_VEC_BINOP_MIXIN,
__SYCL_VEC_OPASSIGN_MIXIN, __SYCL_VEC_UOP_MIXIN)

#undef __SYCL_VEC_UOP_MIXIN
#undef __SYCL_VEC_OPASSIGN_MIXIN
#undef __SYCL_VEC_BINOP_MIXIN

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename Op>
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, std::bit_not<void>>>> {
template <typename T = typename from_incomplete<Self>::element_type>
Expand All @@ -356,6 +455,7 @@ template <typename Self> struct VecOperators {
return apply<std::bit_not<void>>(v);
}
};
#endif

template <typename... Op>
struct __SYCL_EBO CombineImpl : public OpMixin<Op>... {};
Expand All @@ -377,6 +477,7 @@ template <typename Self> struct VecOperators {
OpAssign<ShiftRight>, IncDec> {};
};

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename DataT, int NumElements>
class vec_arith : public VecOperators<vec<DataT, NumElements>>::Combined {};

Expand Down Expand Up @@ -427,6 +528,7 @@ class vec_arith<std::byte, NumElements>
}
};
#endif // (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
#endif

#undef __SYCL_INSTANTIATE_OPERATORS

Expand Down
31 changes: 23 additions & 8 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,18 @@ template <typename DataT> class vec_base<DataT, 1> {
// Provides a cross-platform vector class template that works efficiently on
// SYCL devices as well as in host C++ code.
template <typename DataT, int NumElements>
class __SYCL_EBO vec
: public detail::vec_arith<DataT, NumElements>,
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
// Keep it last to simplify ABI layout test:
public detail::vec_base<DataT, NumElements> {
class __SYCL_EBO vec :
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
public detail::vec_arith<DataT, NumElements>,
#else
public detail::VecOperators<vec<DataT, NumElements>>::Combined,
#endif
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
// Keep it last to simplify ABI layout test:
public detail::vec_base<DataT, NumElements> {
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
"DataT must be cv-unqualified");

Expand Down Expand Up @@ -408,6 +412,7 @@ class __SYCL_EBO vec
constexpr vec &operator=(const vec &) = default;
constexpr vec &operator=(vec &&) = default;

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// Template required to prevent ambiguous overload with the copy assignment
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
Expand All @@ -427,6 +432,14 @@ class __SYCL_EBO vec
*this = Rhs.template as<vec>();
return *this;
}
#else
// Scalar assignment: accepts any T implicitly convertible to DataT, converts
// it to DataT explicitly, builds a vec from that value and assigns it to
// *this.
template <typename T>
std::enable_if_t<std::is_convertible_v<T, DataT>, vec &>
operator=(const T &Rhs) {
  *this = vec{static_cast<DataT>(Rhs)};
  return *this;
}
#endif

__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
static constexpr size_t get_count() { return size(); }
Expand Down Expand Up @@ -536,8 +549,10 @@ class __SYCL_EBO vec
int... T5>
friend class detail::SwizzleOp;
template <typename T1, int T2> friend class __SYCL_EBO vec;
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// To allow arithmetic operators to access private members of vec.
template <typename T1, int T2> friend class detail::vec_arith;
#endif
};
///////////////////////// class sycl::vec /////////////////////////

Expand Down
2 changes: 2 additions & 0 deletions sycl/test-e2e/Basic/vector/byte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ int main() {
assert(SwizByte2Neg[0] == ~SwizByte2B[0]);
}

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
{
// std::byte is not an arithmetic type and it only supports the following
// overloads of >> and << operators.
Expand Down Expand Up @@ -207,6 +208,7 @@ int main() {
assert(SwizShiftRight[0] == SwizByte2Shift[0] >> 3 &&
SwizShiftLeft[1] == SwizByte2Shift[1] << 3);
}
#endif
}

return 0;
Expand Down
4 changes: 2 additions & 2 deletions sycl/test-e2e/Basic/vector/vec_binary_scalar_order.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ bool CheckResult(sycl::vec<T1, N> V, T2 Ref) {
constexpr T RefVal = 2; \
VecT InVec{static_cast<T>(RefVal)}; \
{ \
VecT OutVecsDevice[2]; \
ResT OutVecsDevice[2]; \
T OutRefsDevice[2]; \
{ \
sycl::buffer<VecT, 1> OutVecsBuff{OutVecsDevice, 2}; \
sycl::buffer<ResT, 1> OutVecsBuff{OutVecsDevice, 2}; \
sycl::buffer<T, 1> OutRefsBuff{OutRefsDevice, 2}; \
Q.submit([&](sycl::handler &CGH) { \
sycl::accessor OutVecsAcc{OutVecsBuff, CGH, sycl::read_write}; \
Expand Down
16 changes: 8 additions & 8 deletions sycl/test-e2e/DeviceLib/built-ins/vector_integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ int main() {

// abs
{
s::uint2 r{0};
s::int2 r{0};
{
s::buffer<s::uint2, 1> BufR(&r, s::range<1>(1));
s::buffer<s::int2, 1> BufR(&r, s::range<1>(1));
s::queue myQueue;
myQueue.submit([&](s::handler &cgh) {
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
Expand All @@ -214,8 +214,8 @@ int main() {
});
});
}
unsigned int r1 = r.x();
unsigned int r2 = r.y();
int r1 = r.x();
int r2 = r.y();
assert(r1 == 5);
assert(r2 == 2);
}
Expand All @@ -240,9 +240,9 @@ int main() {

// abs_diff
{
s::uint2 r{0};
s::int2 r{0};
{
s::buffer<s::uint2, 1> BufR(&r, s::range<1>(1));
s::buffer<s::int2, 1> BufR(&r, s::range<1>(1));
s::queue myQueue;
myQueue.submit([&](s::handler &cgh) {
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
Expand All @@ -251,8 +251,8 @@ int main() {
});
});
}
unsigned int r1 = r.x();
unsigned int r2 = r.y();
int r1 = r.x();
int r2 = r.y();
assert(r1 == 4);
assert(r2 == 1);
}
Expand Down
4 changes: 2 additions & 2 deletions sycl/test/basic_tests/vectors/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ using sw_double_2 = decltype(std::declval<vec<double, 4>>().swizzle<1, 2>());
// EXCEPT_IN_PREVIEW condition<>

static_assert( std::is_assignable_v<vec<half, 1>, half>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, float>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, double>);
static_assert( std::is_assignable_v<vec<half, 1>, float>);
static_assert( std::is_assignable_v<vec<half, 1>, double>);
static_assert( std::is_assignable_v<vec<half, 1>, vec<half, 1>>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, vec<float, 1>>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, vec<double, 1>>);
Expand Down