Skip to content

Commit b3b1173

Browse files
committed
Implement WarpReduceBatched
Includes extensive testing
1 parent 0b88dab commit b3b1173

File tree

2 files changed

+858
-0
lines changed

2 files changed

+858
-0
lines changed
Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
// SPDX-License-Identifier: BSD-3
3+
4+
//! @file
5+
//! @rst
6+
//! The ``cub::WarpReduceBatched`` class provides :ref:`collective <collective-primitives>` methods for
7+
//! performing batched parallel reductions of multiple arrays partitioned across a CUDA thread warp.
8+
//! @endrst
9+
10+
#pragma once
11+
12+
#include <cub/config.cuh>
13+
14+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
15+
# pragma GCC system_header
16+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
17+
# pragma clang system_header
18+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
19+
# pragma system_header
20+
#endif // no system header
21+
22+
#include <cub/detail/type_traits.cuh>
23+
#include <cub/thread/thread_operators.cuh>
24+
#include <cub/util_type.cuh>
25+
26+
#include <cuda/__ptx/instructions/get_sreg.h>
27+
#include <cuda/bit>
28+
#include <cuda/cmath>
29+
#include <cuda/std/__functional/operations.h>
30+
#include <cuda/std/__iterator/iterator_traits.h>
31+
#include <cuda/warp>
32+
33+
CUB_NAMESPACE_BEGIN
34+
35+
//! @rst
//! The ``WarpReduceBatched`` class provides :ref:`collective <collective-primitives>` methods for performing
//! batched parallel reductions of multiple arrays partitioned across a CUDA thread warp using the WSPRO algorithm.
//!
//! Overview
//! ++++++++
//!
//! - Performs batched reductions of Batches arrays, each containing LogicalWarpThreads elements
//! - Completes in ``Batches - 1 + log2(LogicalWarpThreads / Batches)`` stages using the WSPRO
//!   (Warp Shuffle Parallel Reduction Optimization) algorithm; when ``Batches == LogicalWarpThreads``
//!   this is ``Batches - 1`` stages
//! - Standard approach (Batches sequential calls to WarpReduce) requires ``Batches * log2(LogicalWarpThreads)`` stages
//! - Best performance when ``Batches == LogicalWarpThreads`` (both powers of 2)
//! - **Output semantics:** Thread i's ``outputs`` array contains:
//!   - ``outputs[0]`` = reduction of array i
//!   - ``outputs[1]`` = reduction of array (i + LogicalWarpThreads), if it exists
//!   - ``outputs[k]`` = reduction of array (i + k * LogicalWarpThreads)
//!
//! Performance Characteristics
//! +++++++++++++++++++++++++++
//!
//! - **Stage count:** ``Batches - 1 + log2(LogicalWarpThreads / Batches)`` stages vs
//!   ``Batches * log2(LogicalWarpThreads)`` for sequential WarpReduce calls
//! - **Example (Batches=8, LogicalWarpThreads=8):** 7 stages (batched) vs 24 stages (sequential) = 3.4x reduction
//! - **Example (Batches=16, LogicalWarpThreads=16):** 15 stages (batched) vs 64 stages (sequential) = 4.3x reduction
//! - Uses warp ``SHFL`` instructions for communication
//! - No shared memory required
//!
//! When to Use
//! +++++++++++
//!
//! - When you need to reduce multiple independent batches within a warp
//! - When ``Batches`` and ``LogicalWarpThreads`` are powers of 2 (required for ``LogicalWarpThreads``, recommended for
//!   ``Batches``)
//! - When computing many small reductions (e.g., per-pixel reductions)
//!
//! Simple Examples
//! +++++++++++++++
//!
//! @warpcollective{WarpReduceBatched}
//!
//! The code snippet below illustrates reduction of 8 batches of 8 elements each:
//!
//! .. code-block:: c++
//!
//!    #include <cub/cub.cuh>
//!
//!    __global__ void ExampleKernel(...)
//!    {
//!        // Specialize WarpReduceBatched for 8 batches of 8 int elements
//!        using WarpReduceBatched = cub::WarpReduceBatched<int, 8, 8>;
//!
//!        // No shared memory is needed: the implementation is shuffle-based
//!
//!        // Each thread provides 8 input values (one element from each of 8 arrays)
//!        int thread_data[8];
//!        for (int i = 0; i < 8; i++)
//!        {
//!            thread_data[i] = ...; // Load element i from thread's position
//!        }
//!
//!        // Perform batched reduction (7 stages vs 24 for sequential).
//!        // With Batches == LogicalWarpThreads, each thread receives exactly one result.
//!        int results[1];
//!        WarpReduceBatched().Reduce(thread_data, results, ::cuda::std::plus<>{});
//!
//!        // Thread i now has the sum of array i in results[0]
//!    }
//!
//! @endrst
//!
//! @tparam T
//!   The reduction input/output element type
//!
//! @tparam Batches
//!   The number of arrays to reduce in batch. Best performance when Batches = LogicalWarpThreads.
//!
//! @tparam LogicalWarpThreads
//!   The number of threads per logical warp / elements per array. Must be a power-of-two in range [2, 32].
//!   Default is the warp size of the targeted CUDA compute-capability (e.g., 32).
//!
template <typename T, int Batches, int LogicalWarpThreads = detail::warp_threads>
class WarpReduceBatched
{
  static_assert(::cuda::is_power_of_two(LogicalWarpThreads), "LogicalWarpThreads must be a power of two");
  // TODO: Should we allow LogicalWarpThreads = 1? (in which case everything is just no-op/copy)
  static_assert(LogicalWarpThreads > 1 && LogicalWarpThreads <= detail::warp_threads,
                "LogicalWarpThreads must be in the range [2, 32]");
  // TODO: Should we restrict to Batches > 1?
  static_assert(Batches >= 1, "Batches must be >= 1");

private:
  //---------------------------------------------------------------------
  // Constants and type definitions
  //---------------------------------------------------------------------

  /// Whether the logical warp size and the PTX warp size coincide
  static constexpr auto is_arch_warp = (LogicalWarpThreads == detail::warp_threads);

  /// Number of reduction results delivered to each thread: ceil_div(Batches, LogicalWarpThreads)
  static constexpr auto max_out_per_thread = ::cuda::ceil_div(Batches, LogicalWarpThreads);

  //---------------------------------------------------------------------
  // Thread fields
  //---------------------------------------------------------------------

  /// Lane index in logical warp
  int lane_id;

public:
  //! @name Collective constructors
  //! @{

  //! @rst
  //! Collective constructor. Logical warp and lane identifiers are derived from the calling
  //! thread's physical lane id. No temporary storage is required (shuffle-based implementation).
  //! @endrst
  _CCCL_DEVICE _CCCL_FORCEINLINE WarpReduceBatched()
      : lane_id(static_cast<int>(::cuda::ptx::get_sreg_laneid()))
  {
    // For sub-warp logical warps, fold the physical lane id into the logical warp's lane range
    if (!is_arch_warp)
    {
      lane_id = lane_id % LogicalWarpThreads;
    }
  }

  //! @}
  //! @name Batched reductions
  //! @{

  //! @rst
  //! Performs batched reduction of ``Batches`` arrays using the specified binary reduction operator.
  //!
  //! Each thread provides ``Batches`` input values (one element from each batch).
  //! The warp collectively reduces each of the ``Batches`` batches (each containing ``LogicalWarpThreads`` elements).
  //! Thread *i* stores results sequentially in its ``outputs`` array:
  //! ``outputs[0]`` = result of batch *i*, ``outputs[1]`` = result of batch *(i + LogicalWarpThreads)*, etc.
  //!
  //! **Algorithm Performance:**
  //!
  //! - Completes in ``Batches - 1 + log2(LogicalWarpThreads / Batches)`` stages
  //! - vs ``Batches * log2(LogicalWarpThreads)`` stages for ``Batches`` sequential ``WarpReduce`` calls
  //! - Example: ``Batches=8``, ``LogicalWarpThreads=8`` -> 7 stages (batched) vs 24 stages (sequential)
  //!
  //! Snippet
  //! +++++++
  //!
  //! The code snippet below illustrates batched reduction of 8 batches:
  //!
  //! .. code-block:: c++
  //!
  //!    #include <cub/cub.cuh>
  //!
  //!    __global__ void ExampleKernel(...)
  //!    {
  //!        using WarpReduceBatched = cub::WarpReduceBatched<int, 8, 8>;
  //!
  //!        cuda::std::array<int, 8> inputs = {...}; // Each thread provides 8 values
  //!        cuda::std::array<int, 1> output;
  //!
  //!        WarpReduceBatched{}.Reduce(
  //!            inputs, output, cuda::std::plus<>{});
  //!
  //!        // Logical warp lane i now has sum of batch i in output[0]
  //!    }
  //!
  //! @endrst
  //!
  //! @tparam InputT
  //!   **[inferred]** Input array-like type (C-array, cuda::std::array, cuda::std::span, etc.)
  //!
  //! @tparam OutputT
  //!   **[inferred]** Output array-like type (C-array, cuda::std::array, cuda::std::span, etc.)
  //!
  //! @tparam ReductionOp
  //!   **[inferred]** Binary reduction operator type having member
  //!   `T operator()(const T &a, const T &b)`
  //!
  //! @param[in] inputs
  //!   Statically-sized array-like container of Batches input values from calling thread
  //!
  //! @param[out] outputs
  //!   Statically-sized array-like container where thread i stores reductions sequentially:
  //!   ``outputs[0]`` = result of batch i, ``outputs[1]`` = result of batch (i + LogicalWarpThreads), etc.
  //!
  //! @param[in] reduction_op
  //!   Binary reduction operator
  //!
  //! @param[in] lane_mask
  //!   Lane mask to restrict the reduction to a subset of the logical warps present in the physical warp.
  //!   Default is all logical warps.
  template <typename InputT, typename OutputT, typename ReductionOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  Reduce(const InputT& inputs,
         OutputT& outputs,
         ReductionOp reduction_op,
         ::cuda::std::uint32_t lane_mask = ::cuda::device::lane_mask::all().value())
  {
    static_assert(::cub::detail::is_fixed_size_random_access_range_v<InputT>,
                  "InputT must support the subscript operator[] and have a compile-time size");
    static_assert(::cub::detail::is_fixed_size_random_access_range_v<OutputT>,
                  "OutputT must support the subscript operator[] and have a compile-time size");
    static_assert(::cub::detail::static_size_v<InputT> == Batches, "Input size must match Batches");
    static_assert(::cub::detail::static_size_v<OutputT> == max_out_per_thread,
                  "Output size must match ceil_div(Batches, LogicalWarpThreads)");
    // These restrictions could be relaxed to allow type-conversions
    static_assert(::cuda::std::is_same_v<::cuda::std::iter_value_t<InputT>, T>, "Input element type must match T");
    static_assert(::cuda::std::is_same_v<::cuda::std::iter_value_t<OutputT>, T>, "Output element type must match T");

    // Need writeable array as scratch space
    auto values = ::cuda::std::array<T, Batches>{};
#pragma unroll
    for (int i = 0; i < Batches; ++i)
    {
      values[i] = inputs[i];
    }

    ReduceInplace(values, reduction_op, lane_mask);

    // After ReduceInplace, result k for this lane sits at values[k]; copy out only the
    // results that exist (the last thread(s) may own fewer when Batches % LogicalWarpThreads != 0)
#pragma unroll
    for (int i = 0; i < max_out_per_thread; ++i)
    {
      const auto batch_idx = i * LogicalWarpThreads + lane_id;
      if (batch_idx < Batches)
      {
        outputs[i] = values[i];
      }
    }
  }

  // TODO: Public for benchmarking purposes only.
  template <typename InputT, typename ReductionOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ReduceInplace(
    InputT& inputs, ReductionOp reduction_op, ::cuda::std::uint32_t lane_mask = ::cuda::device::lane_mask::all().value())
  {
    static_assert(detail::is_fixed_size_random_access_range_v<InputT>,
                  "InputT must support the subscript operator[] and have a compile-time size");
    static_assert(detail::static_size_v<InputT> == Batches, "Input size must match Batches");
    static_assert(::cuda::std::is_same_v<::cuda::std::iter_value_t<InputT>, T>, "Input element type must match T");
#if defined(_CCCL_ASSERT_DEVICE)
    const auto logical_warp_leader =
      ::cuda::round_down(static_cast<int>(::cuda::ptx::get_sreg_laneid()), LogicalWarpThreads);
    const auto logical_warp_mask = ::cuda::bitmask(logical_warp_leader, LogicalWarpThreads);
    // The assertion lives inside the guard so it never names identifiers that only exist
    // when device assertions are enabled
    _CCCL_ASSERT((lane_mask & logical_warp_mask) == logical_warp_mask,
                 "lane_mask must be consistent for each logical warp");
#endif // _CCCL_ASSERT_DEVICE

#pragma unroll
    for (int stride_intra_reduce = 1; stride_intra_reduce < LogicalWarpThreads; stride_intra_reduce *= 2)
    {
      const auto stride_inter_reduce = 2 * stride_intra_reduce;
      // A lane is "left" within its pair group of width stride_inter_reduce
      const auto is_left_lane =
        static_cast<int>(::cuda::ptx::get_sreg_laneid()) % stride_inter_reduce < stride_intra_reduce;

#pragma unroll
      for (int i = 0; i < Batches; i += stride_inter_reduce)
      {
        auto left_value = inputs[i];
        const auto right_idx = i + stride_intra_reduce;
        // Needed for Batches < LogicalWarpThreads case
        // Chose to redundantly operate on the last batch to avoid relying on default construction of T
        const auto safe_right_idx = right_idx < Batches ? right_idx : Batches - 1;
        auto right_value = inputs[safe_right_idx];
        // Each left lane exchanges its right value against a right lane's left value
        if (is_left_lane)
        {
          ::cuda::std::swap(left_value, right_value);
        }
        left_value = ::cuda::device::warp_shuffle_xor(left_value, stride_intra_reduce, lane_mask);
        // While the current implementation is possibly faster, another conditional swap here would allow for
        // non-commutative reductions which might be useful for (segmented) scan operations.
        inputs[i] = reduction_op(left_value, right_value);
      }
    }
    // Make sure results are in the beginning of the array instead of strided
#pragma unroll
    for (int i = 1; i < max_out_per_thread; ++i)
    {
      const auto batch_idx =
        i * LogicalWarpThreads + static_cast<int>(::cuda::ptx::get_sreg_laneid()) % LogicalWarpThreads;
      if (batch_idx < Batches)
      {
        inputs[i] = inputs[i * LogicalWarpThreads];
      }
    }
  }

  //! @rst
  //! Performs batched sum reduction of Batches arrays.
  //!
  //! Convenience wrapper for ``Reduce`` with ``::cuda::std::plus<>`` operator.
  //!
  //! @smemwarpreuse
  //!
  //! @endrst
  //!
  //! @tparam InputT
  //!   **[inferred]** Input array-like type (C-array, cuda::std::array, cuda::std::span, etc.)
  //!
  //! @tparam OutputT
  //!   **[inferred]** Output array-like type (C-array, cuda::std::array, cuda::std::span, etc.)
  //!
  //! @param[in] inputs
  //!   Statically-sized array-like container of Batches input values from calling thread
  //!
  //! @param[out] outputs
  //!   Statically-sized array-like container where thread i stores sums sequentially:
  //!   ``outputs[0]`` = sum of array i, ``outputs[1]`` = sum of array (i + LogicalWarpThreads), etc.
  //!
  //! @param[in] lane_mask
  //!   Lane mask to restrict the reduction to a subset of the logical warps present in the physical warp.
  //!   Default is all logical warps.
  template <typename InputT, typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void Sum(
    const InputT& inputs, OutputT& outputs, ::cuda::std::uint32_t lane_mask = ::cuda::device::lane_mask::all().value())
  {
    Reduce(inputs, outputs, ::cuda::std::plus<>{}, lane_mask);
  }

  //! @}
};
351+
352+
CUB_NAMESPACE_END

0 commit comments

Comments
 (0)