Skip to content
Open
14 changes: 14 additions & 0 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,20 @@ template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformShuffle(__spv::Scope::Flag, ValueT, IdT) noexcept;

// Declaration only (no definition is visible here) of the
// __spirv_GroupNonUniformShuffleXor built-in. It takes a scope flag, the value
// to exchange and an id/mask operand, and returns a ValueT.
// NOTE(review): presumably lowers to SPIR-V OpGroupNonUniformShuffleXor, i.e.
// each invocation reads the value held by invocation (own id ^ mask) - confirm
// against the SPIR-V specification.
template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformShuffleXor(__spv::Scope::Flag, ValueT, IdT) noexcept;

// Declaration only (no definition is visible here) of the
// __spirv_GroupNonUniformShuffleUp built-in. It takes a scope flag, the value
// to exchange and a delta operand, and returns a ValueT.
// NOTE(review): presumably lowers to SPIR-V OpGroupNonUniformShuffleUp, i.e.
// each invocation reads the value held by invocation (own id - delta) -
// confirm against the SPIR-V specification.
template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, ValueT, IdT) noexcept;

// Declaration only (no definition is visible here) of the
// __spirv_GroupNonUniformShuffleDown built-in. It takes a scope flag, the
// value to exchange and a delta operand, and returns a ValueT.
// NOTE(review): presumably lowers to SPIR-V OpGroupNonUniformShuffleDown, i.e.
// each invocation reads the value held by invocation (own id + delta) -
// confirm against the SPIR-V specification.
template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
__SYCL_EXPORT ValueT __spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag,
ValueT,
IdT) noexcept;

// Non-template declaration: takes a scope flag and a bool, and returns bool.
// Unlike the shuffle declarations above, it is not marked noexcept.
// NOTE(review): presumably lowers to SPIR-V OpGroupNonUniformAll (true when
// the predicate holds for every active invocation in the scope) - confirm
// against the SPIR-V specification.
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
__spirv_GroupNonUniformAll(__spv::Scope::Flag, bool);

Expand Down
88 changes: 18 additions & 70 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,30 +796,19 @@ AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
// variants for all scalar types
#ifndef __NVPTX__

// Compile-time predicate: `value` is true when vector_element_t<T> is one of
// double, long, long long, unsigned long, unsigned long long or half, and
// false otherwise.
// NOTE(review): vector_element_t<T> presumably yields T itself for scalar
// types - confirm against its definition.
template <typename T>
struct TypeIsProhibitedForShuffleEmulation
    : std::integral_constant<
          bool, check_type_in_v<vector_element_t<T>, double, long, long long,
                                unsigned long, unsigned long long, half>> {};

// Compile-time predicate: `value` is true only when T reports more than one
// element through detail::get_vec_size AND its element type is flagged by
// TypeIsProhibitedForShuffleEmulation.
template <typename T>
struct VecTypeIsProhibitedForShuffleEmulation
    : std::integral_constant<
          bool, (detail::get_vec_size<T>::size > 1) &&
                    TypeIsProhibitedForShuffleEmulation<
                        vector_element_t<T>>::value> {};

// Note: Although SPIR-V supports vector shuffles, the OpenCL specification only
// allows scalars in these operations. As such, we scalarize those too, then
// expect vectorization from the device compiler if possible.
// https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_cl_khr_subgroup_shuffle
template <typename T>
using EnableIfNativeShuffle =
std::enable_if_t<detail::is_arithmetic<T>::value &&
!VecTypeIsProhibitedForShuffleEmulation<T>::value &&
!detail::is_marray_v<T>,
!detail::is_marray_v<T> && !detail::is_vec_v<T>,
T>;

template <typename T>
using EnableIfNonScalarShuffle =
std::enable_if_t<VecTypeIsProhibitedForShuffleEmulation<T>::value ||
detail::is_marray_v<T>,
T>;
std::enable_if_t<detail::is_marray_v<T> || detail::is_vec_v<T>, T>;

#else // ifndef __NVPTX__

Expand Down Expand Up @@ -924,23 +913,8 @@ EnableIfNativeShuffle<T> Shuffle(GroupT g, T x, id<1> local_id) {
uint32_t LocalId = MapShuffleID(g, local_id);
#ifndef __NVPTX__
std::ignore = g;
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = Shuffle(g, x[s], local_id);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
convertToOpenCLType(x), LocalId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId);
}
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
convertToOpenCLType(x), LocalId);
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
Expand All @@ -957,16 +931,7 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
#ifndef __NVPTX__
std::ignore = g;
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = ShuffleXor(g, x[s], mask);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
GroupT>) {
// Since the masks are relative to the groups, we could either try to adjust
// the mask or simply do the xor ourselves. The latter option is efficient,
// general, and simple, so we go with that.
Expand All @@ -976,8 +941,9 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleXorINTEL(convertToOpenCLType(x),
static_cast<uint32_t>(mask.get(0)));
return __spirv_GroupNonUniformShuffleXor(
__spv::Scope::Subgroup, convertToOpenCLType(x),
static_cast<uint32_t>(mask.get(0)));
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand All @@ -1004,16 +970,7 @@ template <typename GroupT, typename T>
EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
#ifndef __NVPTX__
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = ShuffleDown(g, x[s], delta);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
GroupT>) {
id<1> TargetLocalId = g.get_local_id();
// ID outside the group range is UB, so we just keep the current item ID
// unchanged.
Expand All @@ -1024,8 +981,8 @@ EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleDownINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
return __spirv_GroupNonUniformShuffleDown(__spv::Scope::Subgroup,
convertToOpenCLType(x), delta);
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand All @@ -1049,16 +1006,7 @@ template <typename GroupT, typename T>
EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
#ifndef __NVPTX__
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = ShuffleUp(g, x[s], delta);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
GroupT>) {
id<1> TargetLocalId = g.get_local_id();
// Underflow is UB, so we just keep the current item ID unchanged.
if (TargetLocalId[0] >= delta)
Expand All @@ -1068,8 +1016,8 @@ EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleUpINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
return __spirv_GroupNonUniformShuffleUp(__spv::Scope::Subgroup,
convertToOpenCLType(x), delta);
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand Down
19 changes: 0 additions & 19 deletions sycl/include/syclcompat/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@
#include <sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp>
#endif

// TODO: Remove these function definitions once they exist in the DPC++ compiler
#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER)
// Local forward declaration of __spirv_GroupNonUniformShuffle, with the id
// operand fixed to `unsigned` and the function marked noduplicate. Per the
// TODO above, it is a stop-gap until the DPC++ compiler provides it.
template <typename T>
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT
__attribute__((noduplicate)) T
__spirv_GroupNonUniformShuffle(__spv::Scope::Flag, T, unsigned) noexcept;

// Local forward declaration of __spirv_GroupNonUniformShuffleDown, with the
// delta operand fixed to `unsigned` and the function marked noduplicate. It is
// covered by the TODO at the top of this #if block (a stop-gap until the DPC++
// compiler provides it).
template <typename T>
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT
__attribute__((noduplicate)) T
__spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag, T,
unsigned) noexcept;

// Local forward declaration of __spirv_GroupNonUniformShuffleUp, with the
// delta operand fixed to `unsigned` and the function marked noduplicate. It is
// covered by the TODO at the top of this #if block (a stop-gap until the DPC++
// compiler provides it).
template <typename T>
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT
__attribute__((noduplicate)) T
__spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, T, unsigned) noexcept;
#endif

namespace syclcompat {

namespace detail {
Expand Down
Loading