@@ -848,253 +848,34 @@ class BlockLoad
848848{
849849 static constexpr int BlockThreads = BlockDimX * BlockDimY * BlockDimZ; // total threads in the block
850850
851- template <BlockLoadAlgorithm _POLICY, int Dummy>
852- struct LoadInternal ; // helper to dispatch the load algorithm
853-
854- template <int Dummy>
855- struct LoadInternal <BLOCK_LOAD_DIRECT, Dummy>
851+ // transposing load algorithms need a BlockExchange
852+ using block_exchange =
853+ BlockExchange<T,
854+ BlockDimX,
855+ ItemsPerThread,
856+ /* WarpTimeSlicing = */ Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED,
857+ BlockDimY,
858+ BlockDimZ>;
859+
860+ static_assert ((Algorithm != BLOCK_LOAD_WARP_TRANSPOSE && Algorithm != BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
861+ || (BlockThreads % detail::warp_threads == 0 ),
862+ " BlockThreads must be a multiple of warp_threads for this BlockLoadAlgorithm" );
863+
864+ _CCCL_API static constexpr auto temp_storage_helper ()
856865 {
857- using TempStorage = NullType;
858- int linear_tid;
859-
860- _CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal (TempStorage& /* temp_storage*/ , int linear_tid)
861- : linear_tid(linear_tid)
862- {}
863-
864- template <typename RandomAccessIterator>
865- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
866- {
867- LoadDirectBlocked (linear_tid, block_src_it, dst_items);
868- }
869-
870- template <typename RandomAccessIterator>
871- _CCCL_DEVICE _CCCL_FORCEINLINE void
872- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
873- {
874- LoadDirectBlocked (linear_tid, block_src_it, dst_items, block_items_end);
875- }
876-
877- template <typename RandomAccessIterator, typename DefaultT>
878- _CCCL_DEVICE _CCCL_FORCEINLINE void
879- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
866+ if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_STRIPED
867+ || Algorithm == BLOCK_LOAD_VECTORIZE)
880868 {
881- LoadDirectBlocked (linear_tid, block_src_it, dst_items, block_items_end, oob_default) ;
869+ return NullType{} ;
882870 }
883- };
884-
885- template <int Dummy>
886- struct LoadInternal <BLOCK_LOAD_STRIPED, Dummy>
887- {
888- using TempStorage = NullType;
889- int linear_tid;
890-
891- _CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal (TempStorage& /* temp_storage*/ , int linear_tid)
892- : linear_tid(linear_tid)
893- {}
894-
895- template <typename RandomAccessIterator>
896- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
871+ else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE
872+ || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
897873 {
898- LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items) ;
874+ return typename block_exchange::TempStorage{} ;
899875 }
876+ }
900877
901- template <typename RandomAccessIterator>
902- _CCCL_DEVICE _CCCL_FORCEINLINE void
903- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
904- {
905- LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
906- }
907-
908- template <typename RandomAccessIterator, typename DefaultT>
909- _CCCL_DEVICE _CCCL_FORCEINLINE void
910- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
911- {
912- LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
913- }
914- };
915-
916- template <int Dummy>
917- struct LoadInternal <BLOCK_LOAD_VECTORIZE, Dummy>
918- {
919- using TempStorage = NullType;
920- int linear_tid;
921-
922- _CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal (TempStorage& /* temp_storage*/ , int linear_tid)
923- : linear_tid(linear_tid)
924- {}
925-
926- // attempts vectorization (pointer)
927- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (const T* block_ptr, T (&dst_items)[ItemsPerThread])
928- {
929- InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, dst_items);
930- }
931- // NOTE: This function is necessary for pointers to non-const types.
932- // The core reason is that the compiler will not deduce 'T*' to 'const T*' automatically.
933- // Otherwise, when the pointer type is 'T*', the compiler will prefer the overloaded version
934- // Load(RandomAccessIterator...) over Load(const T*...), which means it will never perform vectorized loading for
935- // pointers to non-const types.
936- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (T* block_ptr, T (&dst_items)[ItemsPerThread])
937- {
938- InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, dst_items);
939- }
940-
941- // any other iterator, no vectorization
942- template <typename RandomAccessIterator>
943- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
944- {
945- LoadDirectBlocked (linear_tid, block_src_it, dst_items);
946- }
947-
948- // attempts vectorization (cache modified iterator)
949- template <CacheLoadModifier MODIFIER, typename ValueType, typename OffsetT>
950- _CCCL_DEVICE _CCCL_FORCEINLINE void
951- Load (CacheModifiedInputIterator<MODIFIER, ValueType, OffsetT> block_src_it, T (&dst_items)[ItemsPerThread])
952- {
953- InternalLoadDirectBlockedVectorized<MODIFIER>(linear_tid, block_src_it.ptr , dst_items);
954- }
955-
956- // skips vectorization
957- template <typename RandomAccessIterator>
958- _CCCL_DEVICE _CCCL_FORCEINLINE void
959- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
960- {
961- LoadDirectBlocked (linear_tid, block_src_it, dst_items, block_items_end);
962- }
963-
964- // skips vectorization
965- template <typename RandomAccessIterator, typename DefaultT>
966- _CCCL_DEVICE _CCCL_FORCEINLINE void
967- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
968- {
969- LoadDirectBlocked (linear_tid, block_src_it, dst_items, block_items_end, oob_default);
970- }
971- };
972-
973- template <int Dummy>
974- struct LoadInternal <BLOCK_LOAD_TRANSPOSE, Dummy>
975- {
976- using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, false , BlockDimY, BlockDimZ>;
977- using _TempStorage = typename BlockExchange::TempStorage;
978- using TempStorage = Uninitialized<_TempStorage>;
979-
980- _TempStorage& temp_storage;
981- int linear_tid;
982-
983- _CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal (TempStorage& temp_storage, int linear_tid)
984- : temp_storage(temp_storage.Alias())
985- , linear_tid(linear_tid)
986- {}
987-
988- template <typename RandomAccessIterator>
989- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
990- {
991- LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
992- BlockExchange (temp_storage).StripedToBlocked (dst_items, dst_items);
993- }
994-
995- template <typename RandomAccessIterator>
996- _CCCL_DEVICE _CCCL_FORCEINLINE void
997- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
998- {
999- LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
1000- BlockExchange (temp_storage).StripedToBlocked (dst_items, dst_items);
1001- }
1002-
1003- template <typename RandomAccessIterator, typename DefaultT>
1004- _CCCL_DEVICE _CCCL_FORCEINLINE void
1005- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
1006- {
1007- LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1008- BlockExchange (temp_storage).StripedToBlocked (dst_items, dst_items);
1009- }
1010- };
1011-
1012- template <int Dummy>
1013- struct LoadInternal <BLOCK_LOAD_WARP_TRANSPOSE, Dummy>
1014- {
1015- static constexpr int WARP_THREADS = detail::warp_threads;
1016- static_assert (BlockThreads % WARP_THREADS == 0 , " BlockThreads must be a multiple of WARP_THREADS" );
1017-
1018- using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, false , BlockDimY, BlockDimZ>;
1019- using _TempStorage = typename BlockExchange::TempStorage;
1020- using TempStorage = Uninitialized<_TempStorage>;
1021-
1022- _TempStorage& temp_storage;
1023- int linear_tid;
1024-
1025- _CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal (TempStorage& temp_storage, int linear_tid)
1026- : temp_storage(temp_storage.Alias())
1027- , linear_tid(linear_tid)
1028- {}
1029-
1030- template <typename RandomAccessIterator>
1031- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
1032- {
1033- LoadDirectWarpStriped (linear_tid, block_src_it, dst_items);
1034- BlockExchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1035- }
1036-
1037- template <typename RandomAccessIterator>
1038- _CCCL_DEVICE _CCCL_FORCEINLINE void
1039- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
1040- {
1041- LoadDirectWarpStriped (linear_tid, block_src_it, dst_items, block_items_end);
1042- BlockExchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1043- }
1044-
1045- template <typename RandomAccessIterator, typename DefaultT>
1046- _CCCL_DEVICE _CCCL_FORCEINLINE void
1047- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
1048- {
1049- LoadDirectWarpStriped (linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1050- BlockExchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1051- }
1052- };
1053-
1054- template <int Dummy>
1055- struct LoadInternal <BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED, Dummy>
1056- {
1057- static constexpr int WARP_THREADS = detail::warp_threads;
1058- static_assert (BlockThreads % WARP_THREADS == 0 , " BlockThreads must be a multiple of WARP_THREADS" );
1059-
1060- using BlockExchange = BlockExchange<T, BlockDimX, ItemsPerThread, true , BlockDimY, BlockDimZ>;
1061- using _TempStorage = typename BlockExchange::TempStorage;
1062- using TempStorage = Uninitialized<_TempStorage>;
1063-
1064- _TempStorage& temp_storage;
1065- int linear_tid;
1066-
1067- _CCCL_DEVICE _CCCL_FORCEINLINE LoadInternal (TempStorage& temp_storage, int linear_tid)
1068- : temp_storage(temp_storage.Alias())
1069- , linear_tid(linear_tid)
1070- {}
1071-
1072- template <typename RandomAccessIterator>
1073- _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
1074- {
1075- LoadDirectWarpStriped (linear_tid, block_src_it, dst_items);
1076- BlockExchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1077- }
1078-
1079- template <typename RandomAccessIterator>
1080- _CCCL_DEVICE _CCCL_FORCEINLINE void
1081- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
1082- {
1083- LoadDirectWarpStriped (linear_tid, block_src_it, dst_items, block_items_end);
1084- BlockExchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1085- }
1086-
1087- template <typename RandomAccessIterator, typename DefaultT>
1088- _CCCL_DEVICE _CCCL_FORCEINLINE void
1089- Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
1090- {
1091- LoadDirectWarpStriped (linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1092- BlockExchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1093- }
1094- };
1095-
1096- using InternalLoad = LoadInternal<Algorithm, 0 >; // load implementation to use
1097- using _TempStorage = typename InternalLoad::TempStorage;
878+ using _TempStorage = decltype (temp_storage_helper());
1098879
1099880 // Internal storage allocator
1100881 _CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage ()
@@ -1188,7 +969,40 @@ public:
1188969 template <typename RandomAccessIterator>
1189970 _CCCL_DEVICE _CCCL_FORCEINLINE void Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread])
1190971 {
1191- InternalLoad (temp_storage, linear_tid).Load (block_src_it, dst_items);
972+ if constexpr (Algorithm == BLOCK_LOAD_DIRECT)
973+ {
974+ LoadDirectBlocked (linear_tid, block_src_it, dst_items);
975+ }
976+ else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
977+ {
978+ LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
979+ }
980+ else if constexpr (Algorithm == BLOCK_LOAD_VECTORIZE)
981+ {
982+ if constexpr (detail::is_CacheModifiedInputIterator<RandomAccessIterator>)
983+ {
984+ InternalLoadDirectBlockedVectorized<RandomAccessIterator::__modifier>(linear_tid, block_src_it.ptr , dst_items);
985+ }
986+ // FIXME(bgruber): we should test for contiguous iterator here
987+ else if constexpr (::cuda::std::is_pointer_v<RandomAccessIterator>)
988+ {
989+ InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_src_it, dst_items);
990+ }
991+ else
992+ {
993+ LoadDirectBlocked (linear_tid, block_src_it, dst_items);
994+ }
995+ }
996+ else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
997+ {
998+ LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items);
999+ block_exchange (temp_storage).StripedToBlocked (dst_items, dst_items);
1000+ }
1001+ else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
1002+ {
1003+ LoadDirectWarpStriped (linear_tid, block_src_it, dst_items);
1004+ block_exchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1005+ }
11921006 }
11931007
11941008 // ! @rst
@@ -1244,7 +1058,24 @@ public:
12441058 _CCCL_DEVICE _CCCL_FORCEINLINE void
12451059 Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end)
12461060 {
1247- InternalLoad (temp_storage, linear_tid).Load (block_src_it, dst_items, block_items_end);
1061+ if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_VECTORIZE)
1062+ {
1063+ LoadDirectBlocked (linear_tid, block_src_it, dst_items, block_items_end);
1064+ }
1065+ else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
1066+ {
1067+ LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
1068+ }
1069+ else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
1070+ {
1071+ LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end);
1072+ block_exchange (temp_storage).StripedToBlocked (dst_items, dst_items);
1073+ }
1074+ else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
1075+ {
1076+ LoadDirectWarpStriped (linear_tid, block_src_it, dst_items, block_items_end);
1077+ block_exchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1078+ }
12481079 }
12491080
12501081 // ! @rst
@@ -1303,7 +1134,24 @@ public:
13031134 _CCCL_DEVICE _CCCL_FORCEINLINE void
13041135 Load (RandomAccessIterator block_src_it, T (&dst_items)[ItemsPerThread], int block_items_end, DefaultT oob_default)
13051136 {
1306- InternalLoad (temp_storage, linear_tid).Load (block_src_it, dst_items, block_items_end, oob_default);
1137+ if constexpr (Algorithm == BLOCK_LOAD_DIRECT || Algorithm == BLOCK_LOAD_VECTORIZE)
1138+ {
1139+ LoadDirectBlocked (linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1140+ }
1141+ else if constexpr (Algorithm == BLOCK_LOAD_STRIPED)
1142+ {
1143+ LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1144+ }
1145+ else if constexpr (Algorithm == BLOCK_LOAD_TRANSPOSE)
1146+ {
1147+ LoadDirectStriped<BlockThreads>(linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1148+ block_exchange (temp_storage).StripedToBlocked (dst_items, dst_items);
1149+ }
1150+ else if constexpr (Algorithm == BLOCK_LOAD_WARP_TRANSPOSE || Algorithm == BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED)
1151+ {
1152+ LoadDirectWarpStriped (linear_tid, block_src_it, dst_items, block_items_end, oob_default);
1153+ block_exchange (temp_storage).WarpStripedToBlocked (dst_items, dst_items);
1154+ }
13071155 }
13081156
13091157 // ! @}
0 commit comments