diff --git a/sycl/include/sycl/accessor_image.hpp b/sycl/include/sycl/accessor_image.hpp index 420e3468d22ab..f784561bafd95 100644 --- a/sycl/include/sycl/accessor_image.hpp +++ b/sycl/include/sycl/accessor_image.hpp @@ -12,20 +12,21 @@ namespace sycl { inline namespace _V1 { namespace detail { -template struct IsValidCoordDataT; -template struct IsValidCoordDataT<1, T> { - constexpr static bool value = detail::is_contained< - T, detail::type_list>::type::value; +template struct IsValidCoordDataT; +template struct IsValidCoordDataT<1, T, AllowFP> { + constexpr static bool value = + std::is_same_v || + (AllowFP && std::is_same_v); }; -template struct IsValidCoordDataT<2, T> { - constexpr static bool value = detail::is_contained< - T, detail::type_list, - vec>>::type::value; +template struct IsValidCoordDataT<2, T, AllowFP> { + constexpr static bool value = + std::is_same_v> || + (AllowFP && std::is_same_v>); }; -template struct IsValidCoordDataT<3, T> { - constexpr static bool value = detail::is_contained< - T, detail::type_list, - vec>>::type::value; +template struct IsValidCoordDataT<3, T, AllowFP> { + constexpr static bool value = + std::is_same_v> || + (AllowFP && std::is_same_v>); }; template struct IsValidUnsampledCoord2020DataT; @@ -448,12 +449,12 @@ class image_accessor // (accessTarget == access::target::image && accessMode == access::mode::read) // || (accessTarget == access::target::host_image && ( accessMode == // access::mode::read || accessMode == access::mode::read_write)) - template 0) && (IsValidCoordDataT::value) && - (detail::is_genint_v) && - ((IsImageAcc && IsImageAccessReadOnly) || - (IsHostImageAcc && IsImageAccessAnyRead))>> + template < + typename CoordT, int Dims = Dimensions, + typename = std::enable_if_t< + (IsValidCoordDataT::value) && + ((IsImageAcc && IsImageAccessReadOnly) || + (IsHostImageAcc && IsImageAccessAnyRead))>> DataT read(const CoordT &Coords) const { #ifdef __SYCL_DEVICE_ONLY__ return __invoke__ImageRead(MImageObj, Coords); @@ -470,7 +471,7 @@ class image_accessor // access::mode::read || accessMode == access::mode::read_write)) template 0) && (IsValidCoordDataT::value) && + (IsValidCoordDataT::value) && ((IsImageAcc && IsImageAccessReadOnly) || (IsHostImageAcc && IsImageAccessAnyRead))>> DataT read(const CoordT &Coords, const sampler &Smpl) const { @@ -494,10 +495,10 @@ class image_accessor // accessMode == access::mode::read_write)) template < typename CoordT, int Dims = Dimensions, - typename = std::enable_if_t<(Dims > 0) && (detail::is_genint_v) && - (IsValidCoordDataT::value) && - ((IsImageAcc && IsImageAccessWriteOnly) || - (IsHostImageAcc && IsImageAccessAnyWrite))>> + typename = std::enable_if_t< + (IsValidCoordDataT::value) && + ((IsImageAcc && IsImageAccessWriteOnly) || + (IsHostImageAcc && IsImageAccessAnyWrite))>> void write(const CoordT &Coords, const DataT &Color) const { #ifdef __SYCL_DEVICE_ONLY__ __invoke__ImageWrite(MImageObj, Coords, Color); @@ -546,23 +547,21 @@ class __image_array_slice__ { size_t Idx) : MBaseAcc(BaseAcc), MIdx(Idx) {} - template 0) && (IsValidCoordDataT::value)>> + template < + typename CoordT, int Dims = Dimensions, + typename = std::enable_if_t<(IsValidCoordDataT::value)>> DataT read(const CoordT &Coords) const { return MBaseAcc.read(getAdjustedCoords(Coords)); } template 0) && - IsValidCoordDataT::value>> + typename = std::enable_if_t::value>> DataT read(const CoordT &Coords, const sampler &Smpl) const { return MBaseAcc.read(getAdjustedCoords(Coords), Smpl); } template 0) && - IsValidCoordDataT::value>> + typename = std::enable_if_t::value>> void write(const CoordT &Coords, const DataT &Color) const { return MBaseAcc.write(getAdjustedCoords(Coords), Color); } diff --git a/sycl/include/sycl/builtins_utils_vec.hpp b/sycl/include/sycl/builtins_utils_vec.hpp index eeaff9450b031..178c696495c8e 100644 --- a/sycl/include/sycl/builtins_utils_vec.hpp +++ b/sycl/include/sycl/builtins_utils_vec.hpp @@ -8,6 +8,8 @@ #pragma once +#include + #include #include @@ -41,19 +43,6 @@ template struct is_valid_elem_type : std::bool_constant> {}; -// Utility trait for getting the number of elements in T. -template -struct num_elements : std::integral_constant {}; -template -struct num_elements> : std::integral_constant {}; -template -struct num_elements> : std::integral_constant {}; -template class OperationCurrentT, int... Indexes> -struct num_elements> - : std::integral_constant {}; - // Utilty trait for checking that the number of elements in T is in Ns. template struct is_valid_size diff --git a/sycl/include/sycl/detail/generic_type_lists.hpp b/sycl/include/sycl/detail/generic_type_lists.hpp deleted file mode 100644 index c6a62504efbec..0000000000000 --- a/sycl/include/sycl/detail/generic_type_lists.hpp +++ /dev/null @@ -1,514 +0,0 @@ -//==-------- generic_type_lists.hpp - SYCL Generic type lists --------------==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include // for address_space -#include // for type_list, address_space_list - -#include // for byte, size_t -#include // for conditional_t, is_signed_v, is_... - -// Generic type name description, which serves as a description for all valid -// types of parameters to kernel functions - -// Forward declarations -namespace sycl { -inline namespace _V1 { -template class __SYCL_EBO vec; -template class marray; - -namespace detail { -namespace half_impl { -class half; -} -} // namespace detail -using half = detail::half_impl::half; - -namespace ext::oneapi { -class bfloat16; -} -namespace detail { -namespace gtl { -// floating point types -using scalar_half_list = type_list; - -using vector_half_list = type_list, vec, vec, - vec, vec, vec>; - -using marray_half_list = - type_list, marray, marray, - marray, marray, marray>; - -using scalar_vector_half_list = tl_append; - -using half_list = - tl_append; - -using scalar_bfloat16_list = type_list; - -using vector_bfloat16_list = type_list< - vec, vec, - vec, vec, - vec, vec>; - -using marray_bfloat16_list = type_list, - marray, - marray, - marray, - marray, - marray>; - -using scalar_vector_bfloat16_list = - tl_append; - -using bfloat16_list = - tl_append; - -using half_bfloat16_list = tl_append; - -using scalar_float_list = type_list; - -using vector_float_list = - type_list, vec, vec, vec, - vec, vec>; - -using marray_float_list = - type_list, marray, marray, - marray, marray, marray>; - -using scalar_vector_float_list = - tl_append; - -using float_list = - tl_append; - -using scalar_double_list = type_list; - -using vector_double_list = - type_list, vec, vec, vec, - vec, vec>; - -using marray_double_list = - type_list, marray, marray, - marray, marray, marray>; - -using scalar_vector_double_list = - tl_append; - -using double_list = - tl_append; - -using scalar_floating_list = tl_append; - -using vector_floating_list = tl_append; - -using marray_floating_list = tl_append; - -using scalar_vector_floating_list = - tl_append; - -using floating_list = - tl_append; - -using scalar_default_char_list = type_list; - -using vector_default_char_list = - type_list, vec, vec, vec, - vec, vec>; - -using marray_default_char_list = - type_list, marray, marray, - marray, marray, marray>; - -using default_char_list = - tl_append; - -using scalar_signed_char_list = type_list; - -using vector_signed_char_list = - type_list, vec, vec, - vec, vec, vec>; - -using marray_signed_char_list = - type_list, marray, - marray, marray, - marray, marray>; - -using scalar_unsigned_char_list = type_list; - -using vector_unsigned_char_list = - type_list, vec, - vec, vec, - vec, vec>; - -using marray_unsigned_char_list = - type_list, marray, - marray, marray, - marray, marray>; - -// short int types -using scalar_signed_short_list = type_list; - -using vector_signed_short_list = - type_list, vec, vec, - vec, vec, - vec>; - -using marray_signed_short_list = - type_list, marray, - marray, marray, - marray, marray>; - -using scalar_unsigned_short_list = type_list; - -using vector_unsigned_short_list = - type_list, vec, - vec, vec, - vec, vec>; - -using marray_unsigned_short_list = - type_list, marray, - marray, marray, - marray, marray>; - -using unsigned_short_list = - tl_append; - -using scalar_short_list = - tl_append; - -using vector_short_list = - tl_append; - -using short_list = tl_append; - -// int types -using scalar_signed_int_list = type_list; - -using vector_signed_int_list = - type_list, vec, vec, - vec, vec, vec>; - -using marray_signed_int_list = - type_list, marray, - marray, marray, - marray, marray>; - -using signed_int_list = - tl_append; - -using scalar_unsigned_int_list = type_list; - -using vector_unsigned_int_list = - type_list, vec, vec, - vec, vec, - vec>; - -using marray_unsigned_int_list = - type_list, marray, - marray, marray, - marray, marray>; - -using unsigned_int_list = - tl_append; - -using scalar_int_list = - tl_append; - -using vector_int_list = - tl_append; - -using marray_int_list = - tl_append; - -using int_list = tl_append; - -// long types -using scalar_signed_long_list = type_list; - -using vector_signed_long_list = - type_list, vec, vec, - vec, vec, vec>; - -using marray_signed_long_list = - type_list, marray, - marray, marray, - marray, marray>; - -using signed_long_list = - tl_append; - -using scalar_unsigned_long_list = type_list; - -using vector_unsigned_long_list = - type_list, vec, - vec, vec, - vec, vec>; - -using marray_unsigned_long_list = - type_list, marray, - marray, marray, - marray, marray>; - -using unsigned_long_list = - tl_append; - -using scalar_long_list = - tl_append; - -using vector_long_list = - tl_append; - -using marray_long_list = - tl_append; - -using long_list = - tl_append; - -// long long types -using scalar_signed_longlong_list = type_list; - -using vector_signed_longlong_list = - type_list, vec, - vec, vec, - vec, vec>; - -using marray_signed_longlong_list = - type_list, marray, - marray, marray, - marray, marray>; - -using signed_longlong_list = - tl_append; - -using scalar_unsigned_longlong_list = type_list; - -using vector_unsigned_longlong_list = - type_list, vec, - vec, vec, - vec, vec>; - -using marray_unsigned_longlong_list = - type_list, marray, - marray, marray, - marray, marray>; - -using unsigned_longlong_list = - tl_append; - -using scalar_longlong_list = - tl_append; - -using vector_longlong_list = - tl_append; - -using marray_longlong_list = - tl_append; - -using longlong_list = - tl_append; - -// long integer types -using scalar_signed_long_integer_list = - tl_append; - -using vector_signed_long_integer_list = - tl_append; - -using marray_signed_long_integer_list = - tl_append; - -using signed_long_integer_list = - tl_append; - -using scalar_unsigned_long_integer_list = - tl_append; - -using vector_unsigned_long_integer_list = - tl_append; - -using marray_unsigned_long_integer_list = - tl_append; - -using unsigned_long_integer_list = tl_append; - -using scalar_long_integer_list = tl_append; - -using vector_long_integer_list = tl_append; - -using marray_long_integer_list = tl_append; - -using long_integer_list = - tl_append; - -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) -// std::byte -using scalar_byte_list = type_list; - -using vector_byte_list = - type_list, vec, vec, - vec, vec, vec>; - -using marray_byte_list = type_list, marray, - marray, marray, - marray, marray>; -#endif - -// integer types -using scalar_signed_integer_list = - tl_append, - tl_append, - scalar_signed_char_list>, - scalar_signed_short_list, scalar_signed_int_list, - scalar_signed_long_list, scalar_signed_longlong_list>; - -using vector_signed_integer_list = - tl_append, - tl_append, - vector_signed_char_list>, - vector_signed_short_list, vector_signed_int_list, - vector_signed_long_list, vector_signed_longlong_list>; - -using marray_signed_integer_list = - tl_append, - tl_append, - marray_signed_char_list>, - marray_signed_short_list, marray_signed_int_list, - marray_signed_long_list, marray_signed_longlong_list>; - -using signed_integer_list = - tl_append; - -using scalar_unsigned_integer_list = - tl_append, - tl_append, - scalar_unsigned_char_list>, - scalar_unsigned_short_list, scalar_unsigned_int_list, - scalar_unsigned_long_list, scalar_unsigned_longlong_list -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) - , - scalar_byte_list -#endif - >; - -using vector_unsigned_integer_list = - tl_append, - tl_append, - vector_unsigned_char_list>, - vector_unsigned_short_list, vector_unsigned_int_list, - vector_unsigned_long_list, vector_unsigned_longlong_list -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) - , - vector_byte_list -#endif - >; - -using marray_unsigned_integer_list = - tl_append, - tl_append, - marray_unsigned_char_list>, - marray_unsigned_short_list, marray_unsigned_int_list, - marray_unsigned_long_list, marray_unsigned_longlong_list -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) - , - marray_byte_list -#endif - >; - -using unsigned_integer_list = - tl_append; - -using scalar_integer_list = - tl_append; - -using vector_integer_list = - tl_append; - -using marray_integer_list = - tl_append; - -using integer_list = - tl_append; - -// bool types - -using marray_bool_list = - type_list, marray, marray, - marray, marray, marray>; - -using scalar_bool_list = type_list; - -using bool_list = tl_append; - -using vector_bool_list = type_list, vec, vec, - vec, vec, vec>; - -// basic types -using scalar_signed_basic_list = - tl_append; - -using vector_signed_basic_list = - tl_append; - -using marray_signed_basic_list = - tl_append; - -using signed_basic_list = - tl_append; - -using scalar_unsigned_basic_list = tl_append; - -using unsigned_basic_list = - tl_append; - -using vector_basic_list = - tl_append; - -} // namespace gtl -namespace gvl { -// address spaces -using nonconst_address_space_list = address_space_list< - access::address_space::local_space, access::address_space::global_space, - access::address_space::private_space, access::address_space::generic_space, - access::address_space::ext_intel_global_device_space, - access::address_space::ext_intel_global_host_space>; - -} // namespace gvl -} // namespace detail -} // namespace _V1 -} // namespace sycl diff --git a/sycl/include/sycl/detail/generic_type_traits.hpp b/sycl/include/sycl/detail/generic_type_traits.hpp index aafef77a12067..8f977e2247a44 100644 --- a/sycl/include/sycl/detail/generic_type_traits.hpp +++ b/sycl/include/sycl/detail/generic_type_traits.hpp @@ -10,7 +10,6 @@ #include // for decorated, address_space #include // for half, cl_char, cl_double -#include // for nonconst_address_space... #include // for marray #include // for is_contained, find_sam... #include // for is_gen_based_on_type_s... @@ -28,76 +27,56 @@ namespace sycl { inline namespace _V1 { namespace detail { template -inline constexpr bool is_svgenfloatf_v = - is_contained_v; - -template -inline constexpr bool is_svgenfloatd_v = - is_contained_v; - -template -inline constexpr bool is_half_v = is_contained_v; - -template -inline constexpr bool is_bfloat16_v = - is_contained_v; - -template -inline constexpr bool is_half_or_bf16_v = - is_contained_v; - -template -inline constexpr bool is_svgenfloath_v = - is_contained_v; +using is_byte = typename +#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) + std::is_same; +#else + std::false_type; +#endif -template -inline constexpr bool is_genfloat_v = is_contained_v; +template inline constexpr bool is_byte_v = is_byte::value; template inline constexpr bool is_sgenfloat_v = - is_contained_v; + check_type_in_v; template inline constexpr bool is_vgenfloat_v = - is_contained_v; + is_vec_v && is_sgenfloat_v>; template -inline constexpr bool is_svgenfloat_v = - is_contained_v; - -template -inline constexpr bool is_genint_v = is_contained_v; - -template -inline constexpr bool is_geninteger_v = is_contained_v; - -template -using is_geninteger = std::bool_constant>; - -template -inline constexpr bool is_sgeninteger_v = - is_contained_v; +inline constexpr bool is_genfloat_v = + is_sgenfloat_v || is_vgenfloat_v || + (is_marray_v && is_sgenfloat_v> && + is_allowed_vec_size_v>); template inline constexpr bool is_sigeninteger_v = - is_contained_v; + check_type_in_v || + (std::is_same_v && std::is_signed_v); template inline constexpr bool is_sugeninteger_v = - is_contained_v; + check_type_in_v || + (std::is_same_v && std::is_unsigned_v) || is_byte_v; template -inline constexpr bool is_genbool_v = is_contained_v; +inline constexpr bool is_sgeninteger_v = + is_sigeninteger_v || is_sugeninteger_v; template -using is_byte = typename -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) - std::is_same; -#else - std::false_type; -#endif +inline constexpr bool is_geninteger_v = + is_sgeninteger_v || + (is_vec_v && is_sgeninteger_v>) || + (is_marray_v && is_sgeninteger_v> && + is_allowed_vec_size_v>); -template inline constexpr bool is_byte_v = is_byte::value; +template +inline constexpr bool is_genbool_v = + std::is_same_v || + (is_marray_v && std::is_same_v, bool> && + is_allowed_vec_size_v>); template using fixed_width_unsigned = std::conditional_t< @@ -152,10 +131,11 @@ template auto convertToOpenCLType(T &&x) { // sycl::half may convert to _Float16, and we would try to instantiate // vec class with _Float16 DataType, which is not expected there. As // such, leave vector as-is. - using MatchingVec = vec, ElemTy, - decltype(convertToOpenCLType( - std::declval()))>, - no_ref::size()>; + using MatchingVec = + vec, ElemTy, + decltype(convertToOpenCLType( + std::declval()))>, + no_ref::size()>; #ifdef __SYCL_DEVICE_ONLY__ return sycl::bit_cast(x); #else @@ -171,11 +151,11 @@ template auto convertToOpenCLType(T &&x) { fixed_width_unsigned>; static_assert(sizeof(OpenCLType) == sizeof(T)); return static_cast(x); - } else if constexpr (is_half_v) { + } else if constexpr (std::is_same_v) { using OpenCLType = sycl::detail::half_impl::BIsRepresentationT; static_assert(sizeof(OpenCLType) == sizeof(T)); return static_cast(x); - } else if constexpr (is_bfloat16_v) { + } else if constexpr (std::is_same_v) { // On host, don't interpret BF16 as uint16. #ifdef __SYCL_DEVICE_ONLY__ using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT; diff --git a/sycl/include/sycl/detail/type_traits.hpp b/sycl/include/sycl/detail/type_traits.hpp index 90790633faa9f..df5f9ea53a045 100644 --- a/sycl/include/sycl/detail/type_traits.hpp +++ b/sycl/include/sycl/detail/type_traits.hpp @@ -8,8 +8,9 @@ #pragma once +#include + #include // for decorated, address_space -#include // for vec, marray, integer_list #include // for is_contained, find_twi... #include // for array @@ -171,18 +172,6 @@ template struct get_elem_type_unqual { using type = ElementType; }; -template -struct is_ext_vector : std::false_type {}; - -// FIXME: unguarded use of non-standard built-in -template -struct is_ext_vector< - T, std::void_t()))>> - : std::true_type {}; - -template -inline constexpr bool is_ext_vector_v = is_ext_vector::value; - // FIXME: unguarded use of non-standard built-in template struct get_elem_type_unqual>> { @@ -255,11 +244,6 @@ template class S> inline constexpr bool is_gen_based_on_type_sizeof_v = S::value && (sizeof(vector_element_t) == N); -template struct is_vec : std::false_type {}; -template struct is_vec> : std::true_type {}; - -template constexpr bool is_vec_v = is_vec::value; - template struct get_vec_size { static constexpr int size = 1; }; @@ -268,27 +252,6 @@ template struct get_vec_size> { static constexpr int size = N; }; -// is_swizzle -template struct is_swizzle : std::false_type {}; -template class OperationCurrentT, int... Indexes> -struct is_swizzle> : std::true_type {}; - -template constexpr bool is_swizzle_v = is_swizzle::value; - -// is_swizzle_or_vec_v - -template -constexpr bool is_vec_or_swizzle_v = is_vec_v || is_swizzle_v; - -// is_marray -template struct is_marray : std::false_type {}; -template -struct is_marray> : std::true_type {}; - -template constexpr bool is_marray_v = is_marray::value; - // is_integral template struct is_integral : std::is_integral> {}; diff --git a/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp b/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp new file mode 100644 index 0000000000000..8097451352e75 --- /dev/null +++ b/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp @@ -0,0 +1,108 @@ +//==---------- Forward declarations and traits for vector/marray types -----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include + +namespace sycl { +inline namespace _V1 { +template class __SYCL_EBO vec; + +template class marray; + +namespace detail { +template class OperationCurrentT, int... Indexes> +class SwizzleOp; + +// --------- is_* traits ------------------ // +template struct is_vec : std::false_type {}; +template struct is_vec> : std::true_type {}; +template constexpr bool is_vec_v = is_vec::value; + +template +struct is_ext_vector : std::false_type {}; +#if defined(__has_extension) +#if __has_extension(attribute_ext_vector_type) +template +struct is_ext_vector : std::true_type {}; +#endif +#endif +template +inline constexpr bool is_ext_vector_v = is_ext_vector::value; + +template struct is_swizzle : std::false_type {}; +template class OperationCurrentT, int... Indexes> +struct is_swizzle> : std::true_type {}; +template constexpr bool is_swizzle_v = is_swizzle::value; + +template +constexpr bool is_vec_or_swizzle_v = is_vec_v || is_swizzle_v; + +template struct is_marray : std::false_type {}; +template +struct is_marray> : std::true_type {}; +template constexpr bool is_marray_v = is_marray::value; + +// --------- num_elements trait ------------------ // +template +struct num_elements : std::integral_constant {}; +template +struct num_elements> : std::integral_constant {}; +template +struct num_elements> + : std::integral_constant {}; +#if defined(__has_extension) +#if __has_extension(attribute_ext_vector_type) +template +struct num_elements + : std::integral_constant {}; +#endif +#endif +template class OperationCurrentT, int... Indexes> +struct num_elements> + : std::integral_constant {}; + +template +inline constexpr std::size_t num_elements_v = num_elements::value; + +// --------- element_type trait ------------------ // +template struct element_type { + using type = T; +}; +template struct element_type> { + using type = T; +}; +template struct element_type> { + using type = T; +}; +#if defined(__has_extension) +#if __has_extension(attribute_ext_vector_type) +template +struct element_type { + using type = T; +}; +#endif +#endif +template using element_type_t = typename element_type::type; + +template +inline constexpr bool is_allowed_vec_size_v = + N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16; + +} // namespace detail +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp b/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp index ee1bea39cae69..4352705693730 100644 --- a/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp @@ -11,7 +11,6 @@ #include #include #include -#include #include #include diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp index facc486ca2f84..28d6864750855 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp @@ -95,9 +95,13 @@ namespace native { // sycl::native::tanh is only implemented on nvptx backend so far. For other // backends we revert to the sycl::tanh impl. template -inline __SYCL_ALWAYS_INLINE std::enable_if_t< - sycl::detail::is_svgenfloatf_v || sycl::detail::is_svgenfloath_v, T> -tanh(T x) __NOEXC { +inline __SYCL_ALWAYS_INLINE + std::enable_if_t || std::is_same_v || + (detail::is_vec_v && + (std::is_same_v, float> || + std::is_same_v, half>)), + T> + tanh(T x) __NOEXC { #if defined(__NVPTX__) return sycl::detail::convertFromOpenCLTypeFor( __clc_native_tanh(sycl::detail::convertToOpenCLType(x))); @@ -144,7 +148,10 @@ inline __SYCL_ALWAYS_INLINE // For other backends we revert to the sycl::exp2 impl. template inline __SYCL_ALWAYS_INLINE - std::enable_if_t, T> + std::enable_if_t || + (detail::is_vec_v && + std::is_same_v, half>), + T> exp2(T x) __NOEXC { #if defined(__NVPTX__) return sycl::detail::convertFromOpenCLTypeFor( diff --git a/sycl/include/sycl/ext/oneapi/experimental/cuda/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/cuda/builtins.hpp index 06d5318c525db..f96878a123acf 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/cuda/builtins.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/cuda/builtins.hpp @@ -44,17 +44,13 @@ using ldg_vector_types = sycl::detail::type_list< sycl::vec, sycl::vec, sycl::vec, sycl::vec, sycl::vec, sycl::vec, sycl::vec, sycl::vec, sycl::vec>; - -using ldg_types = - sycl::detail::tl_append; } // namespace detail template inline __SYCL_ALWAYS_INLINE std::enable_if_t< - sycl::detail::is_contained< - T, sycl::ext::oneapi::experimental::cuda::detail::ldg_types>::value, + sycl::detail::is_sgeninteger_v || sycl::detail::is_sgenfloat_v || + sycl::detail::is_contained::value, T> ldg(const T *ptr) { #if defined(__SYCL_DEVICE_ONLY__) diff --git a/sycl/include/sycl/types.hpp b/sycl/include/sycl/types.hpp index 87698c835a3ee..599ba89819241 100644 --- a/sycl/include/sycl/types.hpp +++ b/sycl/include/sycl/types.hpp @@ -12,7 +12,6 @@ #include // for half, cl_char, cl_int #include // for ArrayCreator, RepeatV... #include // for __SYCL2020_DEPRECATED -#include // for vector_basic_list #include // for is_sigeninteger, is_s... #include #include // for is_contained diff --git a/sycl/include/sycl/vector.hpp b/sycl/include/sycl/vector.hpp index c8d2273094c60..01e70f639e7b5 100644 --- a/sycl/include/sycl/vector.hpp +++ b/sycl/include/sycl/vector.hpp @@ -30,10 +30,8 @@ #include // for half, cl_char, cl_int #include // for ArrayCreator, RepeatV... #include // for __SYCL2020_DEPRECATED -#include // for vector_basic_list #include // for is_sigeninteger, is_s... #include // for memcpy -#include // for is_contained #include // for is_floating_point #include #include // for StorageT, half, Vec16... @@ -121,6 +119,11 @@ struct ScalarConversionOperatorMixIn> { operator T() const { return (*static_cast(this))[0]; } }; +template +inline constexpr bool is_fundamental_or_half_or_bfloat16 = + std::is_fundamental_v || std::is_same_v, half> || + std::is_same_v, ext::oneapi::bfloat16>; + } // namespace detail ///////////////////////// class sycl::vec ///////////////////////// @@ -134,8 +137,7 @@ class __SYCL_EBO vec static_assert(std::is_same_v>, "DataT must be cv-unqualified"); - static_assert(NumElements == 1 || NumElements == 2 || NumElements == 3 || - NumElements == 4 || NumElements == 8 || NumElements == 16, + static_assert(detail::is_allowed_vec_size_v, "Invalid number of elements for sycl::vec: only 1, 2, 3, 4, 8 " "or 16 are supported"); static_assert(sizeof(bool) == sizeof(uint8_t), "bool size is not 1 byte"); @@ -290,10 +292,8 @@ class __SYCL_EBO vec // when NumElements == 1. The template prevents implicit conversion from // vec<_, 1> to DataT. template - typename std::enable_if_t< - std::is_fundamental_v || - detail::is_half_or_bf16_v>, - vec &> + typename std::enable_if_t, + vec &> operator=(const DataT &Rhs) { *this = vec{Rhs}; return *this; @@ -628,16 +628,14 @@ class SwizzleOp { 1 != IdxNum && SwizzleOp::getNumElements() == IdxNum, T>; template - using EnableIfScalarType = typename std::enable_if_t< - std::is_convertible_v && - (std::is_fundamental_v || - detail::is_half_or_bf16_v>)>; + using EnableIfScalarType = + typename std::enable_if_t && + detail::is_fundamental_or_half_or_bfloat16>; template - using EnableIfNoScalarType = typename std::enable_if_t< - !std::is_convertible_v || - !(std::is_fundamental_v || - detail::is_half_or_bf16_v>)>; + using EnableIfNoScalarType = + typename std::enable_if_t || + !detail::is_fundamental_or_half_or_bfloat16>; template using Swizzle = @@ -1242,11 +1240,9 @@ class SwizzleOp { static_assert((sizeof(Tmp) == sizeof(asT)), "The new SYCL vec type must have the same storage size in " "bytes as this SYCL swizzled vec"); - static_assert( - detail::is_contained::value || - detail::is_contained::value, - "asT must be SYCL vec of a different element type and " - "number of elements specified by asT"); + static_assert(detail::is_vec_v, + "asT must be SYCL vec of a different element type and " + "number of elements specified by asT"); return Tmp.template as(); } diff --git a/sycl/test/basic_tests/generic_type_traits.cpp b/sycl/test/basic_tests/generic_type_traits.cpp index 853473e3c61f3..3586d25cab588 100644 --- a/sycl/test/basic_tests/generic_type_traits.cpp +++ b/sycl/test/basic_tests/generic_type_traits.cpp @@ -23,12 +23,6 @@ int main() { static_assert(d::is_genfloat_v == true); static_assert(d::is_genfloat_v> == true); - static_assert(d::is_half_v); - - static_assert(d::is_bfloat16_v); - static_assert(d::is_half_or_bf16_v); - static_assert(d::is_half_or_bf16_v); - // TODO add checks for the following type traits /* is_doublen diff --git a/sycl/test/include_deps/sycl_accessor.hpp.cpp b/sycl/test/include_deps/sycl_accessor.hpp.cpp index 2dce0b3d7696b..622ea90da7006 100644 --- a/sycl/test/include_deps/sycl_accessor.hpp.cpp +++ b/sycl/test/include_deps/sycl_accessor.hpp.cpp @@ -24,7 +24,7 @@ // CHECK-NEXT: info/aspects.def // CHECK-NEXT: info/aspects_deprecated.def // CHECK-NEXT: detail/type_traits.hpp -// CHECK-NEXT: detail/generic_type_lists.hpp +// CHECK-NEXT: detail/type_traits/vec_marray_traits.hpp // CHECK-NEXT: detail/type_list.hpp // CHECK-NEXT: detail/boost/mp11/algorithm.hpp // CHECK-NEXT: detail/boost/mp11/list.hpp @@ -112,5 +112,5 @@ // CHECK-NEXT: detail/string_view.hpp // CHECK-NEXT: detail/util.hpp // CHECK-NEXT: device_selector.hpp -// CHECK-NEXT: buffer_properties.def +// CHECK-NEXT: properties/buffer_properties.def // CHECK-EMPTY: diff --git a/sycl/test/include_deps/sycl_detail_core.hpp.cpp b/sycl/test/include_deps/sycl_detail_core.hpp.cpp index 917996b10cfa6..c35c4b558d7d9 100644 --- a/sycl/test/include_deps/sycl_detail_core.hpp.cpp +++ b/sycl/test/include_deps/sycl_detail_core.hpp.cpp @@ -25,7 +25,7 @@ // CHECK-NEXT: info/aspects.def // CHECK-NEXT: info/aspects_deprecated.def // CHECK-NEXT: detail/type_traits.hpp -// CHECK-NEXT: detail/generic_type_lists.hpp +// CHECK-NEXT: detail/type_traits/vec_marray_traits.hpp // CHECK-NEXT: detail/type_list.hpp // CHECK-NEXT: detail/boost/mp11/algorithm.hpp // CHECK-NEXT: detail/boost/mp11/list.hpp