Skip to content

Commit d36cdc5

Browse files
Refactor cub::BlockLoad and cub::BlockStore (#8120)
1 parent e337fa4 commit d36cdc5

File tree

3 files changed

+198
-651
lines changed

3 files changed

+198
-651
lines changed

cub/cub/block/block_load.cuh

Lines changed: 92 additions & 244 deletions
Original file line numberDiff line numberDiff line change
@@ -848,253 +848,34 @@ class BlockLoad
848848
{
849849
static constexpr int BlockThreads = BlockDimX * BlockDimY * BlockDimZ; // total threads in the block
850850

851-
template <BlockLoadAlgorithm _POLICY, int Dummy>
852-
struct LoadInternal; // helper to dispatch the load algorithm
853-
854-
template <int Dummy>
855-
struct LoadInternal<BLOCK_LOAD_DIRECT, Dummy>
851+
// transposing load algorithms need a BlockExchange
852+
using block_exchange =
853+
BlockExchange<T,
854+
BlockDimX,
855+
ItemsPerThread,
856+
/* WarpTimeSlicing = */ Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED,
857+
BlockDimY,
858+
BlockDimZ>;
859+
860+
static_assert((Algorithm != BLOCK_LOAD_WARP_TRANSPOSE && Algorithm != BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
861+
|| (BlockThreads % detail::warp_threads == 0),
862+
"BlockThreads must be a multiple of warp_threads for this BlockLoadAlgorithm");
863+
864+
_CCCL_API static constexpr auto temp_storage_helper()
856865
{
857-
using TempStorage = NullType;
858-
int linear_tid;
859-
860-
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& /*temp_storage*/, int linear_tid)
861-
: linear_tid(linear_tid)
862-
{}
863-
864-
template <typename RandomAccessIterator>
865-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
866-
{
867-
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
868-
}
869-
870-
template <typename RandomAccessIterator>
871-
_CCCL_DEVICE _CCCL_FORCEINLINE void
872-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
873-
{
874-
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end);
875-
}
876-
877-
template <typename RandomAccessIterator, typename DefaultT>
878-
_CCCL_DEVICE _CCCL_FORCEINLINE void
879-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
866+
if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_STRIPED
867+
|| Algorithm == BLOCK_LOAD_VECTORIZE)
880868
{
881-
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
869+
return NullType{};
882870
}
883-
};
884-
885-
template <int Dummy>
886-
struct LoadInternal<BLOCK_LOAD_STRIPED, Dummy>
887-
{
888-
using TempStorage = NullType;
889-
int linear_tid;
890-
891-
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& /*temp_storage*/, int linear_tid)
892-
: linear_tid(linear_tid)
893-
{}
894-
895-
template <typename RandomAccessIterator>
896-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
871+
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE
872+
|| Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
897873
{
898-
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
874+
return typename block_exchange::TempStorage{};
899875
}
876+
}
900877

901-
template <typename RandomAccessIterator>
902-
_CCCL_DEVICE _CCCL_FORCEINLINE void
903-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
904-
{
905-
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
906-
}
907-
908-
template <typename RandomAccessIterator, typename DefaultT>
909-
_CCCL_DEVICE _CCCL_FORCEINLINE void
910-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
911-
{
912-
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
913-
}
914-
};
915-
916-
template <int Dummy>
917-
struct LoadInternal<BLOCK_LOAD_VECTORIZE, Dummy>
918-
{
919-
using TempStorage = NullType;
920-
int linear_tid;
921-
922-
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& /*temp_storage*/, int linear_tid)
923-
: linear_tid(linear_tid)
924-
{}
925-
926-
// attempts vectorization (pointer)
927-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(const T* block_ptr, T (&dst_items)[ItemsPerThread])
928-
{
929-
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, dst_items);
930-
}
931-
// NOTE: This function is necessary for pointers to non-const types.
932-
// The core reason is that the compiler will not deduce 'T*' to 'const T*' automatically.
933-
// Otherwise, when the pointer type is 'T*', the compiler will prefer the overloaded version
934-
// Load(RandomAccessIterator...) over Load(const T*...), which means it will never perform vectorized loading for
935-
// pointers to non-const types.
936-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(T* block_ptr, T (&dst_items)[ItemsPerThread])
937-
{
938-
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, dst_items);
939-
}
940-
941-
// any other iterator, no vectorization
942-
template <typename RandomAccessIterator>
943-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
944-
{
945-
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
946-
}
947-
948-
// attempts vectorization (cache modified iterator)
949-
template <CacheLoadModifier MODIFIER, typename ValueType, typename OffsetT>
950-
_CCCL_DEVICE _CCCL_FORCEINLINE void
951-
Load(CacheModifiedInputIterator<MODIFIER, ValueType, OffsetT> block_src_it, T (&dst_items)[ItemsPerThread])
952-
{
953-
InternalLoadDirectBlockedVectorized<MODIFIER>(linear_tid, block_src_it.ptr, dst_items);
954-
}
955-
956-
// skips vectorization
957-
template <typename RandomAccessIterator>
958-
_CCCL_DEVICE _CCCL_FORCEINLINE void
959-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
960-
{
961-
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end);
962-
}
963-
964-
// skips vectorization
965-
template <typename RandomAccessIterator, typename DefaultT>
966-
_CCCL_DEVICE _CCCL_FORCEINLINE void
967-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
968-
{
969-
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
970-
}
971-
};
972-
973-
template <int Dummy>
974-
struct LoadInternal<BLOCK_LOAD_TRANSPOSE, Dummy>
975-
{
976-
using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, false, BlockDimY, BlockDimZ>;
977-
using _TempStorage = typename BlockExchange::TempStorage;
978-
using TempStorage = Uninitialized<_TempStorage>;
979-
980-
_TempStorage& temp_storage;
981-
int linear_tid;
982-
983-
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& temp_storage, int linear_tid)
984-
: temp_storage(temp_storage.Alias())
985-
, linear_tid(linear_tid)
986-
{}
987-
988-
template <typename RandomAccessIterator>
989-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
990-
{
991-
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
992-
BlockExchange(temp_storage).StripedToBlocked(dst_items, dst_items);
993-
}
994-
995-
template <typename RandomAccessIterator>
996-
_CCCL_DEVICE _CCCL_FORCEINLINE void
997-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
998-
{
999-
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
1000-
BlockExchange(temp_storage).StripedToBlocked(dst_items, dst_items);
1001-
}
1002-
1003-
template <typename RandomAccessIterator, typename DefaultT>
1004-
_CCCL_DEVICE _CCCL_FORCEINLINE void
1005-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
1006-
{
1007-
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1008-
BlockExchange(temp_storage).StripedToBlocked(dst_items, dst_items);
1009-
}
1010-
};
1011-
1012-
template <int Dummy>
1013-
struct LoadInternal<BLOCK_LOAD_WARP_TRANSPOSE, Dummy>
1014-
{
1015-
static constexpr int WARP_THREADS = detail::warp_threads;
1016-
static_assert(BlockThreads % WARP_THREADS == 0, "BlockThreads must be a multiple of WARP_THREADS");
1017-
1018-
using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, false, BlockDimY, BlockDimZ>;
1019-
using _TempStorage = typename BlockExchange::TempStorage;
1020-
using TempStorage = Uninitialized<_TempStorage>;
1021-
1022-
_TempStorage& temp_storage;
1023-
int linear_tid;
1024-
1025-
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& temp_storage, int linear_tid)
1026-
: temp_storage(temp_storage.Alias())
1027-
, linear_tid(linear_tid)
1028-
{}
1029-
1030-
template <typename RandomAccessIterator>
1031-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
1032-
{
1033-
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items);
1034-
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1035-
}
1036-
1037-
template <typename RandomAccessIterator>
1038-
_CCCL_DEVICE _CCCL_FORCEINLINE void
1039-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
1040-
{
1041-
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end);
1042-
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1043-
}
1044-
1045-
template <typename RandomAccessIterator, typename DefaultT>
1046-
_CCCL_DEVICE _CCCL_FORCEINLINE void
1047-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
1048-
{
1049-
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1050-
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1051-
}
1052-
};
1053-
1054-
template <int Dummy>
1055-
struct LoadInternal<BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED, Dummy>
1056-
{
1057-
static constexpr int WARP_THREADS = detail::warp_threads;
1058-
static_assert(BlockThreads % WARP_THREADS == 0, "BlockThreads must be a multiple of WARP_THREADS");
1059-
1060-
using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, true, BlockDimY, BlockDimZ>;
1061-
using _TempStorage = typename BlockExchange::TempStorage;
1062-
using TempStorage = Uninitialized<_TempStorage>;
1063-
1064-
_TempStorage& temp_storage;
1065-
int linear_tid;
1066-
1067-
_CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal(TempStorage& temp_storage, int linear_tid)
1068-
: temp_storage(temp_storage.Alias())
1069-
, linear_tid(linear_tid)
1070-
{}
1071-
1072-
template <typename RandomAccessIterator>
1073-
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
1074-
{
1075-
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items);
1076-
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1077-
}
1078-
1079-
template <typename RandomAccessIterator>
1080-
_CCCL_DEVICE _CCCL_FORCEINLINE void
1081-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
1082-
{
1083-
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end);
1084-
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1085-
}
1086-
1087-
template <typename RandomAccessIterator, typename DefaultT>
1088-
_CCCL_DEVICE _CCCL_FORCEINLINE void
1089-
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
1090-
{
1091-
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1092-
BlockExchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1093-
}
1094-
};
1095-
1096-
using InternalLoad = LoadInternal<Algorithm, 0>; // load implementation to use
1097-
using _TempStorage = typename InternalLoad::TempStorage;
878+
using _TempStorage = decltype(temp_storage_helper());
1098879

1099880
// Internal storage allocator
1100881
_CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage()
@@ -1188,7 +969,40 @@ public:
1188969
template <typename RandomAccessIterator>
1189970
_CCCL_DEVICE _CCCL_FORCEINLINE void Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
1190971
{
1191-
InternalLoad(temp_storage, linear_tid).Load(block_src_it, dst_items);
972+
if constexpr (Algorithm == BLOCK_LOAD_DIRECT)
973+
{
974+
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
975+
}
976+
else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
977+
{
978+
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
979+
}
980+
else if constexpr (Algorithm == BLOCK_LOAD_VECTORIZE)
981+
{
982+
if constexpr (detail::is_CacheModifiedInputIterator<RandomAccessIterator>)
983+
{
984+
InternalLoadDirectBlockedVectorized<RandomAccessIterator::__modifier>(linear_tid, block_src_it.ptr, dst_items);
985+
}
986+
// FIXME(bgruber): we should test for contiguous iterator here
987+
else if constexpr (::cuda::std::is_pointer_v<RandomAccessIterator>)
988+
{
989+
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_src_it, dst_items);
990+
}
991+
else
992+
{
993+
LoadDirectBlocked(linear_tid, block_src_it, dst_items);
994+
}
995+
}
996+
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
997+
{
998+
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
999+
block_exchange(temp_storage).StripedToBlocked(dst_items, dst_items);
1000+
}
1001+
else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
1002+
{
1003+
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items);
1004+
block_exchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1005+
}
11921006
}
11931007

11941008
//! @rst
@@ -1244,7 +1058,24 @@ public:
12441058
_CCCL_DEVICE _CCCL_FORCEINLINE void
12451059
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
12461060
{
1247-
InternalLoad(temp_storage, linear_tid).Load(block_src_it, dst_items, block_items_end);
1061+
if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_VECTORIZE)
1062+
{
1063+
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end);
1064+
}
1065+
else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
1066+
{
1067+
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
1068+
}
1069+
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
1070+
{
1071+
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
1072+
block_exchange(temp_storage).StripedToBlocked(dst_items, dst_items);
1073+
}
1074+
else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
1075+
{
1076+
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end);
1077+
block_exchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1078+
}
12481079
}
12491080

12501081
//! @rst
@@ -1303,7 +1134,24 @@ public:
13031134
_CCCL_DEVICE _CCCL_FORCEINLINE void
13041135
Load(RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
13051136
{
1306-
InternalLoad(temp_storage, linear_tid).Load(block_src_it, dst_items, block_items_end, oob_default);
1137+
if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_VECTORIZE)
1138+
{
1139+
LoadDirectBlocked(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1140+
}
1141+
else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
1142+
{
1143+
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1144+
}
1145+
else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
1146+
{
1147+
LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1148+
block_exchange(temp_storage).StripedToBlocked(dst_items, dst_items);
1149+
}
1150+
else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
1151+
{
1152+
LoadDirectWarpStriped(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1153+
block_exchange(temp_storage).WarpStripedToBlocked(dst_items, dst_items);
1154+
}
13071155
}
13081156

13091157
//! @}

0 commit comments

Comments
 (0)