diff --git a/sycl/include/sycl/__spirv/spirv_ops.hpp b/sycl/include/sycl/__spirv/spirv_ops.hpp index 5800190f539a0..64e8d1123c297 100644 --- a/sycl/include/sycl/__spirv/spirv_ops.hpp +++ b/sycl/include/sycl/__spirv/spirv_ops.hpp @@ -899,6 +899,20 @@ template __SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT __spirv_GroupNonUniformShuffle(__spv::Scope::Flag, ValueT, IdT) noexcept; +template +__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT + __spirv_GroupNonUniformShuffleXor(__spv::Scope::Flag, ValueT, IdT) noexcept; + +template +__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT + __spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, ValueT, IdT) noexcept; + +template +__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL + __SYCL_EXPORT ValueT __spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag, + ValueT, + IdT) noexcept; + __SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool __spirv_GroupNonUniformAll(__spv::Scope::Flag, bool); diff --git a/sycl/include/sycl/detail/spirv.hpp b/sycl/include/sycl/detail/spirv.hpp index ad59d129622b9..8b13811245a90 100644 --- a/sycl/include/sycl/detail/spirv.hpp +++ b/sycl/include/sycl/detail/spirv.hpp @@ -796,30 +796,19 @@ AtomicMax(multi_ptr MPtr, memory_scope Scope, // variants for all scalar types #ifndef __NVPTX__ -template -struct TypeIsProhibitedForShuffleEmulation - : std::bool_constant< - check_type_in_v, double, long, long long, - unsigned long, unsigned long long, half>> {}; - -template -struct VecTypeIsProhibitedForShuffleEmulation - : std::bool_constant< - (detail::get_vec_size::size > 1) && - TypeIsProhibitedForShuffleEmulation>::value> {}; - +// Note: Although SPIR-V supports vector shuffles, the OpenCL specification only +// allows scalars in the operations. As such, we scalarize those too, then +// expect vectorization from the device compiler if possible. 
+// https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_cl_khr_subgroup_shuffle template using EnableIfNativeShuffle = std::enable_if_t::value && - !VecTypeIsProhibitedForShuffleEmulation::value && - !detail::is_marray_v, + !detail::is_marray_v && !detail::is_vec_v, T>; template using EnableIfNonScalarShuffle = - std::enable_if_t::value || - detail::is_marray_v, - T>; + std::enable_if_t || detail::is_vec_v, T>; #else // ifndef __NVPTX__ @@ -924,23 +913,8 @@ EnableIfNativeShuffle Shuffle(GroupT g, T x, id<1> local_id) { uint32_t LocalId = MapShuffleID(g, local_id); #ifndef __NVPTX__ std::ignore = g; - if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT> && - detail::is_vec::value) { - // Temporary work-around due to a bug in IGC. - // TODO: Remove when IGC bug is fixed. - T result; - for (int s = 0; s < x.size(); ++s) - result[s] = Shuffle(g, x[s], local_id); - return result; - } else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT>) { - return __spirv_GroupNonUniformShuffle(group_scope::value, - convertToOpenCLType(x), LocalId); - } else { - // Subgroup. - return __spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId); - } + return __spirv_GroupNonUniformShuffle(group_scope::value, + convertToOpenCLType(x), LocalId); #else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< GroupT>) { @@ -957,16 +931,7 @@ EnableIfNativeShuffle ShuffleXor(GroupT g, T x, id<1> mask) { #ifndef __NVPTX__ std::ignore = g; if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT> && - detail::is_vec::value) { - // Temporary work-around due to a bug in IGC. - // TODO: Remove when IGC bug is fixed. 
- T result; - for (int s = 0; s < x.size(); ++s) - result[s] = ShuffleXor(g, x[s], mask); - return result; - } else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT>) { + GroupT>) { // Since the masks are relative to the groups, we could either try to adjust // the mask or simply do the xor ourselves. Latter option is efficient, // general, and simple so we go with that. @@ -976,8 +941,9 @@ EnableIfNativeShuffle ShuffleXor(GroupT g, T x, id<1> mask) { convertToOpenCLType(x), TargetId); } else { // Subgroup. - return __spirv_SubgroupShuffleXorINTEL(convertToOpenCLType(x), - static_cast(mask.get(0))); + return __spirv_GroupNonUniformShuffleXor( + __spv::Scope::Subgroup, convertToOpenCLType(x), + static_cast(mask.get(0))); } #else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< @@ -1004,16 +970,7 @@ template EnableIfNativeShuffle ShuffleDown(GroupT g, T x, uint32_t delta) { #ifndef __NVPTX__ if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT> && - detail::is_vec::value) { - // Temporary work-around due to a bug in IGC. - // TODO: Remove when IGC bug is fixed. - T result; - for (int s = 0; s < x.size(); ++s) - result[s] = ShuffleDown(g, x[s], delta); - return result; - } else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT>) { + GroupT>) { id<1> TargetLocalId = g.get_local_id(); // ID outside the group range is UB, so we just keep the current item ID // unchanged. @@ -1024,8 +981,8 @@ EnableIfNativeShuffle ShuffleDown(GroupT g, T x, uint32_t delta) { convertToOpenCLType(x), TargetId); } else { // Subgroup. 
- return __spirv_SubgroupShuffleDownINTEL(convertToOpenCLType(x), - convertToOpenCLType(x), delta); + return __spirv_GroupNonUniformShuffleDown(__spv::Scope::Subgroup, + convertToOpenCLType(x), delta); } #else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< @@ -1049,16 +1006,7 @@ template EnableIfNativeShuffle ShuffleUp(GroupT g, T x, uint32_t delta) { #ifndef __NVPTX__ if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT> && - detail::is_vec::value) { - // Temporary work-around due to a bug in IGC. - // TODO: Remove when IGC bug is fixed. - T result; - for (int s = 0; s < x.size(); ++s) - result[s] = ShuffleUp(g, x[s], delta); - return result; - } else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< - GroupT>) { + GroupT>) { id<1> TargetLocalId = g.get_local_id(); // Underflow is UB, so we just keep the current item ID unchanged. if (TargetLocalId[0] >= delta) @@ -1068,8 +1016,8 @@ EnableIfNativeShuffle ShuffleUp(GroupT g, T x, uint32_t delta) { convertToOpenCLType(x), TargetId); } else { // Subgroup. 
- return __spirv_SubgroupShuffleUpINTEL(convertToOpenCLType(x), - convertToOpenCLType(x), delta); + return __spirv_GroupNonUniformShuffleUp(__spv::Scope::Subgroup, + convertToOpenCLType(x), delta); } #else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v< diff --git a/sycl/include/syclcompat/util.hpp b/sycl/include/syclcompat/util.hpp index df03599ea6ad0..872421f0d4d4b 100644 --- a/sycl/include/syclcompat/util.hpp +++ b/sycl/include/syclcompat/util.hpp @@ -46,25 +46,6 @@ #include #endif -// TODO: Remove these function definitions once they exist in the DPC++ compiler -#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER) -template -__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT - __attribute__((noduplicate)) T - __spirv_GroupNonUniformShuffle(__spv::Scope::Flag, T, unsigned) noexcept; - -template -__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT - __attribute__((noduplicate)) T - __spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag, T, - unsigned) noexcept; - -template -__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT - __attribute__((noduplicate)) T - __spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, T, unsigned) noexcept; -#endif - namespace syclcompat { namespace detail {