Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 92 additions & 244 deletions cub/cub/block/block_load.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -848,253 +848,34 @@ class BlockLoad
{
static constexpr int BlockThreads = BlockDimX * BlockDimY * BlockDimZ; // total threads in the block

template <BlockLoadAlgorithm _POLICY, int Dummy>
struct LoadInternal; // helper to dispatch the load algorithm

template <int Dummy>
struct LoadInternal<BLOCK_LOAD_DIRECT, Dummy>
// transposing load algorithms need a BlockExchange
using block_exchange =
BlockExchange<T,
BlockDimX,
ItemsPerThread,
/* WarpTimeSlicing = */ Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED,
BlockDimY,
BlockDimZ>;

static_assert((Algorithm != BLOCK_LOAD_WARP_TRANSPOSE && Algorithm != BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
|| (BlockThreads % detail::warp_threads == 0),
"BlockThreads must be a multiple of warp_threads for this BlockLoadAlgorithm");

_CCCL_API static constexpr auto temp_storage_helper()
{
using TempStorage = NullType;
int linear_tid;

_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& /*temp_storage*/, int linear_tid)
: linear_tid(linear_tid)
{}

template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
}

template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end);
}

template <typename RandomAccessIterator, typename DefaultT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_STRIPED
|| Algorithm == BLOCK_LOAD_VECTORIZE)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
return NullType{};
}
};

template <int Dummy>
struct LoadInternal<BLOCK_LOAD_STRIPED, Dummy>
{
using TempStorage = NullType;
int linear_tid;

_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& /*temp_storage*/, int linear_tid)
: linear_tid(linear_tid)
{}

template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE
|| Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
return typename block_exchange::TempStorage{};
}
}

template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
}

template <typename RandomAccessIterator, typename DefaultT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
}
};

// Load implementation selected when the algorithm is BLOCK_LOAD_VECTORIZE.
// It needs no temporary storage (TempStorage is NullType) and only keeps the caller's linear thread id.
// Unguarded loads from raw pointers and from CacheModifiedInputIterator are forwarded to
// InternalLoadDirectBlockedVectorized; every other overload forwards to LoadDirectBlocked.
template <int Dummy>
struct LoadInternal<BLOCK_LOAD_VECTORIZE, Dummy>
{
using TempStorage = NullType;
int linear_tid;

// The temp storage argument is accepted for interface uniformity with the other specializations and ignored.
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& /*temp_storage*/, int linear_tid)
: linear_tid(linear_tid)
{}

// attempts vectorization (pointer)
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(const T* block_ptr, T (&dst_items)[ItemsPerThread])
{
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, dst_items);
}
// NOTE: This function is necessary for pointers to non-const types.
// The core reason is that the compiler will not deduce 'T*' to 'const T*' automatically.
// Otherwise, when the pointer type is 'T*', the compiler will prefer the overloaded version
// Load(RandomAccessIterator...) over Load(const T*...), which means it will never perform vectorized loading for
// pointers to non-const types.
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(T* block_ptr, T (&dst_items)[ItemsPerThread])
{
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, dst_items);
}

// any other iterator, no vectorization
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
}

// attempts vectorization (cache modified iterator)
// The iterator's MODIFIER is passed on as the template argument and its underlying `ptr` member is used as source.
template <CacheLoadModifier MODIFIER, typename ValueType, typename OffsetT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(CacheModifiedInputIterator<MODIFIER, ValueType, OffsetT> block_src_it, T (&dst_items)[ItemsPerThread])
{
InternalLoadDirectBlockedVectorized<MODIFIER>(linear_tid, block_src_it.ptr, dst_items);
}

// skips vectorization
// Guarded overload: block_items_end is forwarded unchanged to LoadDirectBlocked.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end);
}

// skips vectorization
// Guarded overload with a fallback value: block_items_end and oob_default are forwarded unchanged.
template <typename RandomAccessIterator, typename DefaultT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
}
};

// Load implementation selected when the algorithm is BLOCK_LOAD_TRANSPOSE.
// Every overload performs a LoadDirectStriped into dst_items and then calls
// BlockExchange::StripedToBlocked with dst_items as both input and output.
template <int Dummy>
struct LoadInternal<BLOCK_LOAD_TRANSPOSE, Dummy>
{
// NOTE(review): the fourth template argument (false) is presumably the warp-time-slicing switch of BlockExchange;
// confirm against the BlockExchange declaration.
using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, false, BlockDimY, BlockDimZ>;
using _TempStorage = typename BlockExchange::TempStorage;
// Public storage type: the exchange's storage wrapped in Uninitialized<>.
using TempStorage = Uninitialized<_TempStorage>;

// Reference to the caller-provided storage, obtained via Alias() in the constructor.
_TempStorage& temp_storage;
int linear_tid;

// Binds the caller's storage (through Alias()) and records the linear thread id.
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& temp_storage, int linear_tid)
: temp_storage(temp_storage.Alias())
, linear_tid(linear_tid)
{}

// Unguarded load: striped load followed by the striped-to-blocked exchange.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
BlockExchange(temp_storage).StripedToBlocked(dst_items, dst_items);
}

// Guarded load: block_items_end is forwarded to LoadDirectStriped; the exchange itself is unconditional.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
BlockExchange(temp_storage).StripedToBlocked(dst_items, dst_items);
}

// Guarded load with fallback value: block_items_end and oob_default are forwarded to LoadDirectStriped.
template <typename RandomAccessIterator, typename DefaultT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
BlockExchange(temp_storage).StripedToBlocked(dst_items, dst_items);
}
};

// Load implementation selected when the algorithm is BLOCK_LOAD_WARP_TRANSPOSE.
// Every overload performs a LoadDirectWarpStriped into dst_items and then calls
// BlockExchange::WarpStripedToBlocked with dst_items as both input and output.
template <int Dummy>
struct LoadInternal<BLOCK_LOAD_WARP_TRANSPOSE, Dummy>
{
static constexpr int WARP_THREADS = detail::warp_threads;
// Compile-time requirement of this specialization: the block must consist of whole warps.
static_assert(BlockThreads % WARP_THREADS == 0, "BlockThreads must be a multiple of WARP_THREADS");

// NOTE(review): the fourth template argument (false) is presumably the warp-time-slicing switch of BlockExchange;
// confirm against the BlockExchange declaration.
using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, false, BlockDimY, BlockDimZ>;
using _TempStorage = typename BlockExchange::TempStorage;
// Public storage type: the exchange's storage wrapped in Uninitialized<>.
using TempStorage = Uninitialized<_TempStorage>;

// Reference to the caller-provided storage, obtained via Alias() in the constructor.
_TempStorage& temp_storage;
int linear_tid;

// Binds the caller's storage (through Alias()) and records the linear thread id.
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& temp_storage, int linear_tid)
: temp_storage(temp_storage.Alias())
, linear_tid(linear_tid)
{}

// Unguarded load: warp-striped load followed by the warp-striped-to-blocked exchange.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
{
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items);
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}

// Guarded load: block_items_end is forwarded to LoadDirectWarpStriped; the exchange itself is unconditional.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end);
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}

// Guarded load with fallback value: block_items_end and oob_default are forwarded to LoadDirectWarpStriped.
template <typename RandomAccessIterator, typename DefaultT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
{
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}
};

// Load implementation selected when the algorithm is BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED.
// The Load overloads are written identically to the BLOCK_LOAD_WARP_TRANSPOSE specialization; the visible
// difference is that BlockExchange is instantiated with `true` as its fourth template argument.
template <int Dummy>
struct LoadInternal<BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED, Dummy>
{
static constexpr int WARP_THREADS = detail::warp_threads;
// Compile-time requirement of this specialization: the block must consist of whole warps.
static_assert(BlockThreads % WARP_THREADS == 0, "BlockThreads must be a multiple of WARP_THREADS");

// NOTE(review): the fourth template argument (true) is presumably the warp-time-slicing switch of BlockExchange;
// confirm against the BlockExchange declaration.
using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, true, BlockDimY, BlockDimZ>;
using _TempStorage = typename BlockExchange::TempStorage;
// Public storage type: the exchange's storage wrapped in Uninitialized<>.
using TempStorage = Uninitialized<_TempStorage>;

// Reference to the caller-provided storage, obtained via Alias() in the constructor.
_TempStorage& temp_storage;
int linear_tid;

// Binds the caller's storage (through Alias()) and records the linear thread id.
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& temp_storage, int linear_tid)
: temp_storage(temp_storage.Alias())
, linear_tid(linear_tid)
{}

// Unguarded load: warp-striped load followed by the warp-striped-to-blocked exchange.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
{
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items);
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}

// Guarded load: block_items_end is forwarded to LoadDirectWarpStriped; the exchange itself is unconditional.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end);
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}

// Guarded load with fallback value: block_items_end and oob_default are forwarded to LoadDirectWarpStriped.
template <typename RandomAccessIterator, typename DefaultT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
{
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}
};

using InternalLoad = LoadInternal<Algorithm, 0>; // load implementation to use
using _TempStorage = typename InternalLoad::TempStorage;
using _TempStorage = decltype(temp_storage_helper());

// Internal storage allocator
_CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage()
Expand Down Expand Up @@ -1188,7 +969,40 @@ public:
// Unguarded load overload (no block_items_end): fills dst_items from block_src_it, choosing the load routine at
// compile time from Algorithm.
template <typename RandomAccessIterator>
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
{
// NOTE(review): this text is a rendered diff. The InternalLoad call on the next line looks like the pre-change
// implementation that the if-constexpr chain below replaces; confirm that the real file contains only one of them.
InternalLoad(temp_storage, linear_tid).Load(block_src_it, dst_items);
if constexpr (Algorithm == BLOCK_LOAD_DIRECT)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
}
else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
}
else if constexpr (Algorithm == BLOCK_LOAD_VECTORIZE)
{
// Cache-modified iterators: pass the iterator's modifier as template argument and load from its `ptr` member.
if constexpr (detail::is_CacheModifiedInputIterator<RandomAccessIterator>)
{
InternalLoadDirectBlockedVectorized<RandomAccessIterator::__modifier>(linear_tid, block_src_it.ptr, dst_items);
}
// FIXME(bgruber): we should test for contiguous iterator here
// Raw pointers take the vectorized path with LOAD_DEFAULT.
else if constexpr (::cuda::std::is_pointer_v<RandomAccessIterator>)
{
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_src_it, dst_items);
}
// Any other iterator type: plain blocked load.
else
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
}
}
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
{
// Striped load, then exchange in place (dst_items is both input and output).
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
block_exchange(temp_storage).StripedToBlocked(dst_items, dst_items);
}
else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
{
// Both warp-transpose variants share this code path.
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items);
block_exchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}
}

//! @rst
Expand Down Expand Up @@ -1244,7 +1058,24 @@ public:
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
{
// Guarded overload: block_items_end is forwarded unchanged to whichever load routine Algorithm selects.
// NOTE(review): this text is a rendered diff. The InternalLoad call on the next line looks like the pre-change
// implementation that the if-constexpr chain below replaces; confirm that the real file contains only one of them.
InternalLoad(temp_storage, linear_tid).Load(block_src_it, dst_items, block_items_end);
// BLOCK_LOAD_VECTORIZE shares the plain blocked path here; no vectorized routine is called in this overload.
if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_VECTORIZE)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end);
}
else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
}
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
{
// Striped load, then exchange in place (dst_items is both input and output).
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
block_exchange(temp_storage).StripedToBlocked(dst_items, dst_items);
}
else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
{
// Both warp-transpose variants share this code path.
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end);
block_exchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}
}

//! @rst
Expand Down Expand Up @@ -1303,7 +1134,24 @@ public:
_CCCL_DEVICE _CCCL_FORCEINLINE void
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
{
// Guarded overload with a fallback value: block_items_end and oob_default are forwarded unchanged to whichever
// load routine Algorithm selects.
// NOTE(review): this text is a rendered diff. The InternalLoad call on the next line looks like the pre-change
// implementation that the if-constexpr chain below replaces; confirm that the real file contains only one of them.
InternalLoad(temp_storage, linear_tid).Load(block_src_it, dst_items, block_items_end, oob_default);
// BLOCK_LOAD_VECTORIZE shares the plain blocked path here; no vectorized routine is called in this overload.
if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_VECTORIZE)
{
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
}
else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
{
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
}
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
{
// Striped load, then exchange in place (dst_items is both input and output).
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
block_exchange(temp_storage).StripedToBlocked(dst_items, dst_items);
}
else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
{
// Both warp-transpose variants share this code path.
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
block_exchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
}
}

//! @}
Expand Down
Loading
Loading