|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +// SPDX-License-Identifier: BSD-3 |
| 3 | + |
| 4 | +//! @file |
| 5 | +//! @rst |
| 6 | +//! The ``cub::WarpReduceBatched`` class provides :ref:`collective <collective-primitives>` methods for |
| 7 | +//! performing batched parallel reductions of multiple arrays partitioned across a CUDA thread warp. |
| 8 | +//! @endrst |
| 9 | + |
| 10 | +#pragma once |
| 11 | + |
| 12 | +#include <cub/config.cuh> |
| 13 | + |
| 14 | +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) |
| 15 | +# pragma GCC system_header |
| 16 | +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) |
| 17 | +# pragma clang system_header |
| 18 | +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) |
| 19 | +# pragma system_header |
| 20 | +#endif // no system header |
| 21 | + |
| 22 | +#include <cub/detail/type_traits.cuh> |
| 23 | +#include <cub/thread/thread_operators.cuh> |
| 24 | +#include <cub/util_type.cuh> |
| 25 | + |
| 26 | +#include <cuda/__ptx/instructions/get_sreg.h> |
| 27 | +#include <cuda/bit> |
| 28 | +#include <cuda/cmath> |
| 29 | +#include <cuda/std/__functional/operations.h> |
| 30 | +#include <cuda/std/__iterator/iterator_traits.h> |
| 31 | +#include <cuda/warp> |
| 32 | + |
| 33 | +CUB_NAMESPACE_BEGIN |
| 34 | + |
| 35 | +//! @rst |
| 36 | +//! The ``WarpReduceBatched`` class provides :ref:`collective <collective-primitives>` methods for performing |
| 37 | +//! batched parallel reductions of multiple arrays partitioned across a CUDA thread warp using the WSPRO algorithm. |
| 38 | +//! |
| 39 | +//! Overview |
| 40 | +//! ++++++++ |
| 41 | +//! |
| 42 | +//! - Performs batched reductions of Batches arrays, each containing LogicalWarpThreads elements |
| 43 | +//! - Completes in Batches-1 stages using the WSPRO (Warp Shuffle Parallel Reduction Optimization) algorithm |
| 44 | +//! - Standard approach (Batches sequential calls to WarpReduce) requires ``Batches * log2(LogicalWarpThreads)`` stages |
| 45 | +//! - Best performance when ``Batches == LogicalWarpThreads`` (both powers of 2) |
| 46 | +//! - **Output semantics:** Thread i's ``outputs`` array contains: |
| 47 | +//! - ``outputs[0]`` = reduction of array i |
| 48 | +//! - ``outputs[1]`` = reduction of array (i + LogicalWarpThreads), if it exists |
| 49 | +//! - ``outputs[k]`` = reduction of array (i + k * LogicalWarpThreads) |
| 50 | +//! |
| 51 | +//! Performance Characteristics |
| 52 | +//! +++++++++++++++++++++++++++ |
| 53 | +//! |
| 54 | +//! - **Stage count:** Batches-1 stages vs Batches * log2(LogicalWarpThreads) for sequential WarpReduce calls |
| 55 | +//! - **Example (Batches=8, LogicalWarpThreads=8):** 7 stages (batched) vs 24 stages (sequential) = 3.4x reduction |
| 56 | +//! - **Example (Batches=16, LogicalWarpThreads=16):** 15 stages (batched) vs 64 stages (sequential) = 4.3x reduction |
| 57 | +//! - Uses warp ``SHFL`` instructions for communication |
| 58 | +//! - No shared memory required |
| 59 | +//! |
| 60 | +//! When to Use |
| 61 | +//! +++++++++++ |
| 62 | +//! |
| 63 | +//! - When you need to reduce multiple independent batches within a warp |
| 64 | +//! - When ``Batches`` and ``LogicalWarpThreads`` are powers of 2 (required for ``LogicalWarpThreads``, recommended for |
| 65 | +//! ``Batches``) |
| 66 | +//! - When computing many small reductions (e.g., per-pixel reductions) |
| 67 | +//! |
| 68 | +//! Simple Examples |
| 69 | +//! +++++++++++++++ |
| 70 | +//! |
| 71 | +//! @warpcollective{WarpReduceBatched} |
| 72 | +//! |
| 73 | +//! The code snippet below illustrates reduction of 8 batches of 8 elements each: |
| 74 | +//! |
| 75 | +//! .. code-block:: c++ |
| 76 | +//! |
| 77 | +//! #include <cub/cub.cuh> |
| 78 | +//! |
| 79 | +//! __global__ void ExampleKernel(...) |
| 80 | +//! { |
| 81 | +//! // Specialize WarpReduceBatched for 8 batches of 8 int elements |
| 82 | +//! using WarpReduceBatched = cub::WarpReduceBatched<int, 8, 8>; |
| 83 | +//! |
| 84 | +//! // Allocate shared memory (none needed for shuffle-based implementation) |
| 85 | +//! __shared__ typename WarpReduceBatched::TempStorage temp_storage; |
| 86 | +//! |
| 87 | +//! // Each thread provides 8 input values (one element from each of 8 arrays) |
| 88 | +//! int thread_data[8]; |
| 89 | +//! for (int i = 0; i < 8; i++) |
| 90 | +//! { |
| 91 | +//! thread_data[i] = ...; // Load element i from thread's position |
| 92 | +//! } |
| 93 | +//! |
| 94 | +//! // Perform batched reduction (7 stages vs 24 for sequential) |
| 95 | +//! int results[8]; |
| 96 | +//! WarpReduceBatched(temp_storage).ReduceBatched(thread_data, results, ::cuda::std::plus<>{}); |
| 97 | +//! |
| 98 | +//! // Thread i now has the sum of array i in results[i] |
| 99 | +//! } |
| 100 | +//! |
| 101 | +//! @endrst |
| 102 | +//! |
| 103 | +//! @tparam T |
| 104 | +//! The reduction input/output element type |
| 105 | +//! |
| 106 | +//! @tparam Batches |
| 107 | +//! The number of arrays to reduce in batch. Best performance when Batches = LogicalWarpThreads. |
| 108 | +//! |
| 109 | +//! @tparam LogicalWarpThreads |
| 110 | +//! The number of threads per logical warp / elements per array. Must be a power-of-two in range [2, 32]. |
| 111 | +//! Default is the warp size of the targeted CUDA compute-capability (e.g., 32). |
| 112 | +//! |
| 113 | +template <typename T, int Batches, int LogicalWarpThreads = detail::warp_threads> |
| 114 | +class WarpReduceBatched |
| 115 | +{ |
| 116 | + static_assert(::cuda::is_power_of_two(LogicalWarpThreads), "LogicalWarpThreads must be a power of two"); |
| 117 | + // TODO: Should we allow LogicalWarpThreads = 1? (in which case everything is just no-op/copy) |
| 118 | + static_assert(LogicalWarpThreads > 1 && LogicalWarpThreads <= detail::warp_threads, |
| 119 | + "LogicalWarpThreads must be in the range [2, 32]"); |
| 120 | + // TODO: Should we restrict to Batches > 1? |
| 121 | + static_assert(Batches >= 1, "Batches must be >= 1"); |
| 122 | + |
| 123 | +private: |
| 124 | + //--------------------------------------------------------------------- |
| 125 | + // Constants and type definitions |
| 126 | + //--------------------------------------------------------------------- |
| 127 | + |
| 128 | + /// Whether the logical warp size and the PTX warp size coincide |
| 129 | + static constexpr auto is_arch_warp = (LogicalWarpThreads == detail::warp_threads); |
| 130 | + |
| 131 | + static constexpr auto max_out_per_thread = ::cuda::ceil_div(Batches, LogicalWarpThreads); |
| 132 | + |
| 133 | + //--------------------------------------------------------------------- |
| 134 | + // Thread fields |
| 135 | + //--------------------------------------------------------------------- |
| 136 | + |
| 137 | + /// Lane index in logical warp |
| 138 | + int lane_id; |
| 139 | + |
| 140 | +public: |
| 141 | + //! @name Collective constructors |
| 142 | + //! @{ |
| 143 | + |
| 144 | + //! @rst |
| 145 | + //! Collective constructor using the specified memory allocation as temporary storage. |
| 146 | + //! Logical warp and lane identifiers are constructed from ``threadIdx.x``. |
| 147 | + //! @endrst |
| 148 | + _CCCL_DEVICE _CCCL_FORCEINLINE WarpReduceBatched() |
| 149 | + : lane_id(static_cast<int>(::cuda::ptx::get_sreg_laneid())) |
| 150 | + { |
| 151 | + if (!is_arch_warp) |
| 152 | + { |
| 153 | + lane_id = lane_id % LogicalWarpThreads; |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + //! @} |
| 158 | + //! @name Batched reductions |
| 159 | + //! @{ |
| 160 | + |
| 161 | + //! @rst |
| 162 | + //! Performs batched reduction of ``Batches`` arrays using the specified binary reduction operator. |
| 163 | + //! |
| 164 | + //! Each thread provides ``Batches`` input values (one element from each batch). |
| 165 | + //! The warp collectively reduces each of the ``Batches`` batches (each containing ``LogicalWarpThreads`` elements). |
| 166 | + //! Thread *i* stores results sequentially in its ``outputs`` array: |
| 167 | + //! ``outputs[0]`` = result of batch *i*, ``outputs[1]`` = result of batch *(i + LogicalWarpThreads)*, etc. |
| 168 | + //! |
| 169 | + //! **Algorithm Performance:** |
| 170 | + //! |
| 171 | + //! - Completes in ``Batches - 1 + log2(LogicalWarpThreads / Batches)`` stages |
| 172 | + //! - vs ``Batches * log2(LogicalWarpThreads)`` stages for ``Batches`` sequential ``WarpReduce`` calls |
| 173 | + //! - Example: ``Batches=8``, ``LogicalWarpThreads=8`` -> 7 stages (batched) vs 24 stages (sequential) |
| 174 | + //! |
| 175 | + //! Snippet |
| 176 | + //! +++++++ |
| 177 | + //! |
| 178 | + //! The code snippet below illustrates batched reduction of 8 batches: |
| 179 | + //! |
| 180 | + //! .. code-block:: c++ |
| 181 | + //! |
| 182 | + //! #include <cub/cub.cuh> |
| 183 | + //! |
| 184 | + //! __global__ void ExampleKernel(...) |
| 185 | + //! { |
| 186 | + //! using WarpReduceBatched = cub::WarpReduceBatched<int, 8, 8>; |
| 187 | + //! |
| 188 | + //! cuda::std::array<int, 8> inputs = {...}; // Each thread provides 8 values |
| 189 | + //! cuda::std::array<int, 1> output; |
| 190 | + //! |
| 191 | + //! WarpReduceBatched.Reduce( |
| 192 | + //! inputs, output, cuda::std::plus<>{}); |
| 193 | + //! |
| 194 | + //! // Logical warp lane i now has sum of batch i in output[0] |
| 195 | + //! } |
| 196 | + //! |
| 197 | + //! @endrst |
| 198 | + //! |
| 199 | + //! @tparam InputT |
| 200 | + //! **[inferred]** Input array-like type (C-array, cuda::std::array, cuda::std::span, etc.) |
| 201 | + //! |
| 202 | + //! @tparam OutputT |
| 203 | + //! **[inferred]** Output array-like type (C-array, cuda::std::array, cuda::std::span, etc.) |
| 204 | + //! |
| 205 | + //! @tparam ReductionOp |
| 206 | + //! **[inferred]** Binary reduction operator type having member |
| 207 | + //! `T operator()(const T &a, const T &b)` |
| 208 | + //! |
| 209 | + //! @param[in] inputs |
| 210 | + //! Statically-sized array-like container of Batches input values from calling thread |
| 211 | + //! |
| 212 | + //! @param[out] outputs |
| 213 | + //! Statically-sized array-like container where thread i stores reductions sequentially: |
| 214 | + //! ``outputs[0]`` = result of batch i, ``outputs[1]`` = result of batch (i + LogicalWarpThreads), etc. |
| 215 | + //! |
| 216 | + //! @param[in] reduction_op |
| 217 | + //! Binary reduction operator |
| 218 | + //! |
| 219 | + //! @param[in] lane_mask |
| 220 | + //! Lane mask to restrict the reduction to a subset of the logical warps present in the physical warp. |
| 221 | + //! Default is all logical warps. |
| 222 | + template <typename InputT, typename OutputT, typename ReductionOp> |
| 223 | + _CCCL_DEVICE _CCCL_FORCEINLINE void |
| 224 | + Reduce(const InputT& inputs, |
| 225 | + OutputT& outputs, |
| 226 | + ReductionOp reduction_op, |
| 227 | + ::cuda::std::uint32_t lane_mask = ::cuda::device::lane_mask::all().value()) |
| 228 | + { |
| 229 | + static_assert(::cub::detail::is_fixed_size_random_access_range_v<InputT>, |
| 230 | + "InputT must support the subscript operator[] and have a compile-time size"); |
| 231 | + static_assert(::cub::detail::is_fixed_size_random_access_range_v<OutputT>, |
| 232 | + "OutputT must support the subscript operator[] and have a compile-time size"); |
| 233 | + static_assert(::cub::detail::static_size_v<InputT> == Batches, "Input size must match Batches"); |
| 234 | + static_assert(::cub::detail::static_size_v<OutputT> == max_out_per_thread, |
| 235 | + "Output size must match ceil_div(Batches, LogicalWarpThreads)"); |
| 236 | + // These restrictions could be relaxed to allow type-conversions |
| 237 | + static_assert(::cuda::std::is_same_v<::cuda::std::iter_value_t<InputT>, T>, "Input element type must match T"); |
| 238 | + static_assert(::cuda::std::is_same_v<::cuda::std::iter_value_t<OutputT>, T>, "Output element type must match T"); |
| 239 | + |
| 240 | + // Need writeable array as scratch space |
| 241 | + auto values = ::cuda::std::array<T, Batches>{}; |
| 242 | +#pragma unroll |
| 243 | + for (int i = 0; i < Batches; ++i) |
| 244 | + { |
| 245 | + values[i] = inputs[i]; |
| 246 | + } |
| 247 | + |
| 248 | + ReduceInplace(values, reduction_op, lane_mask); |
| 249 | + |
| 250 | +#pragma unroll |
| 251 | + for (int i = 0; i < max_out_per_thread; ++i) |
| 252 | + { |
| 253 | + const auto batch_idx = i * LogicalWarpThreads + lane_id; |
| 254 | + if (batch_idx < Batches) |
| 255 | + { |
| 256 | + outputs[i] = values[i]; |
| 257 | + } |
| 258 | + } |
| 259 | + } |
| 260 | + |
| 261 | + // TODO: Public for benchmarking purposes only. |
| 262 | + template <typename InputT, typename ReductionOp> |
| 263 | + _CCCL_DEVICE _CCCL_FORCEINLINE void ReduceInplace( |
| 264 | + InputT& inputs, ReductionOp reduction_op, ::cuda::std::uint32_t lane_mask = ::cuda::device::lane_mask::all().value()) |
| 265 | + { |
| 266 | + static_assert(detail::is_fixed_size_random_access_range_v<InputT>, |
| 267 | + "InputT must support the subscript operator[] and have a compile-time size"); |
| 268 | + static_assert(detail::static_size_v<InputT> == Batches, "Input size must match Batches"); |
| 269 | + static_assert(::cuda::std::is_same_v<::cuda::std::iter_value_t<InputT>, T>, "Input element type must match T"); |
| 270 | +#if defined(_CCCL_ASSERT_DEVICE) |
| 271 | + const auto logical_warp_leader = |
| 272 | + ::cuda::round_down(static_cast<int>(::cuda::ptx::get_sreg_laneid()), LogicalWarpThreads); |
| 273 | + const auto logical_warp_mask = ::cuda::bitmask(logical_warp_leader, LogicalWarpThreads); |
| 274 | +#endif // _CCCL_ASSERT_DEVICE |
| 275 | + _CCCL_ASSERT((lane_mask & logical_warp_mask) == logical_warp_mask, |
| 276 | + "lane_mask must be consistent for each logical warp"); |
| 277 | + |
| 278 | +#pragma unroll |
| 279 | + for (int stride_intra_reduce = 1; stride_intra_reduce < LogicalWarpThreads; stride_intra_reduce *= 2) |
| 280 | + { |
| 281 | + const auto stride_inter_reduce = 2 * stride_intra_reduce; |
| 282 | + const auto is_left_lane = |
| 283 | + static_cast<int>(::cuda::ptx::get_sreg_laneid()) % (2 * stride_intra_reduce) < stride_intra_reduce; |
| 284 | + |
| 285 | +#pragma unroll |
| 286 | + for (int i = 0; i < Batches; i += stride_inter_reduce) |
| 287 | + { |
| 288 | + auto left_value = inputs[i]; |
| 289 | + const auto right_idx = i + stride_intra_reduce; |
| 290 | + // Needed for Batches < LogicalWarpThreads case |
| 291 | + // Chose to redundantly operate on the last batch to avoid relying on default construction of T |
| 292 | + const auto safe_right_idx = right_idx < Batches ? right_idx : Batches - 1; |
| 293 | + auto right_value = inputs[safe_right_idx]; |
| 294 | + // Each left lane exchanges its right value against a right lane's left value |
| 295 | + if (is_left_lane) |
| 296 | + { |
| 297 | + ::cuda::std::swap(left_value, right_value); |
| 298 | + } |
| 299 | + left_value = ::cuda::device::warp_shuffle_xor(left_value, stride_intra_reduce, lane_mask); |
| 300 | + // While the current implementation is possibly faster, another conditional swap here would allow for |
| 301 | + // non-commutative reductions which might be useful for (segmented) scan operations. |
| 302 | + inputs[i] = reduction_op(left_value, right_value); |
| 303 | + } |
| 304 | + } |
| 305 | + // Make sure results are in the beginning of the array instead of strided |
| 306 | +#pragma unroll |
| 307 | + for (int i = 1; i < max_out_per_thread; ++i) |
| 308 | + { |
| 309 | + const auto batch_idx = |
| 310 | + i * LogicalWarpThreads + static_cast<int>(::cuda::ptx::get_sreg_laneid()) % LogicalWarpThreads; |
| 311 | + if (batch_idx < Batches) |
| 312 | + { |
| 313 | + inputs[i] = inputs[i * LogicalWarpThreads]; |
| 314 | + } |
| 315 | + } |
| 316 | + } |
| 317 | + //! @rst |
| 318 | + //! Performs batched sum reduction of Batches arrays. |
| 319 | + //! |
| 320 | + //! Convenience wrapper for ``ReduceBatched`` with ``::cuda::std::plus<>`` operator. |
| 321 | + //! |
| 322 | + //! @smemwarpreuse |
| 323 | + //! |
| 324 | + //! @endrst |
| 325 | + //! |
| 326 | + //! @tparam InputT |
| 327 | + //! **[inferred]** Input array-like type (C-array, cuda::std::array, cuda::std::span, etc.) |
| 328 | + //! |
| 329 | + //! @tparam OutputT |
| 330 | + //! **[inferred]** Output array-like type (C-array, cuda::std::array, cuda::std::span, etc.) |
| 331 | + //! |
| 332 | + //! @param[in] inputs |
| 333 | + //! Statically-sized array-like container of Batches input values from calling thread |
| 334 | + //! |
| 335 | + //! @param[out] outputs |
| 336 | + //! Statically-sized array-like container where thread i stores sums sequentially: |
| 337 | + //! ``outputs[0]`` = sum of array i, ``outputs[1]`` = sum of array (i + LogicalWarpThreads), etc. |
| 338 | + //! |
| 339 | + //! @param[in] lane_mask |
| 340 | + //! Lane mask to restrict the reduction to a subset of the logical warps present in the physical warp. |
| 341 | + //! Default is all logical warps. |
| 342 | + template <typename InputT, typename OutputT> |
| 343 | + _CCCL_DEVICE _CCCL_FORCEINLINE void Sum( |
| 344 | + const InputT& inputs, OutputT& outputs, ::cuda::std::uint32_t lane_mask = ::cuda::device::lane_mask::all().value()) |
| 345 | + { |
| 346 | + Reduce(inputs, outputs, ::cuda::std::plus<>{}, lane_mask); |
| 347 | + } |
| 348 | + |
| 349 | + //! @} |
| 350 | +}; |
| 351 | + |
| 352 | +CUB_NAMESPACE_END |
0 commit comments