Skip to content

Commit cc55e02

Browse files
authored
Add TMA scheduler for outer2D reduction (#5966)
Add an auto-scheduler for TMA outer reduction. The schedule is similar to the one in [PR#5926](#5926): fixed block dimensions, TMA tile sizes, and unroll factor. It always targets a grid reduction, chosen by a very simple heuristic. <img width="1671" height="667" alt="2026-02-17_05-35" src="https://github.com/user-attachments/assets/2fe57532-30ec-40da-b1be-f3f2b5bf3567" />
1 parent ba2eb55 commit cc55e02

File tree

5 files changed

+502
-12
lines changed

5 files changed

+502
-12
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ list(APPEND NVFUSER_SRCS
333333
${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp
334334
${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp
335335
${NVFUSER_SRCS_DIR}/scheduler/reduction_non_tma.cpp
336+
${NVFUSER_SRCS_DIR}/scheduler/reduction_outer_tma.cpp
336337
${NVFUSER_SRCS_DIR}/scheduler/reduction_tma.cpp
337338
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
338339
${NVFUSER_SRCS_DIR}/scheduler/registry.cpp

csrc/scheduler/reduction.cpp

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "instrumentation.h"
1414
#include "scheduler/debug_utils.h"
1515
#include "scheduler/reduction_non_tma.h"
16+
#include "scheduler/reduction_outer_tma.h"
1617
#include "scheduler/reduction_tma.h"
1718
#include "scheduler/reduction_utils.h"
1819
#include "scheduler/registry_utils.h"
@@ -220,7 +221,8 @@ bool mayUseTma(
220221

221222
// Require reduction dim fits into smem, until we add iteration over large
222223
// reduction dim.
223-
const int64_t smem_elems = (dev_prop->sharedMemPerBlockOptin * 8) /
224+
const int64_t smem_elems =
225+
(static_cast<int64_t>(dev_prop->sharedMemPerBlockOptin) * 8) /
224226
props.max_dtype_size_bit_for_vectorization;
225227

226228
if (props.inner_most_dimension_numel > smem_elems) {
@@ -232,11 +234,62 @@ bool mayUseTma(
232234
return false;
233235
}
234236

235-
// Like vectorization, TMA requires 16-bytes alignment
237+
// Require that the innermost dim is contiguous.
236238
if (props.vectorize_factor <= 1) {
237239
return false;
238240
}
239241

242+
uint64_t vect_bits =
243+
props.vectorize_factor * props.max_dtype_size_bit_for_vectorization;
244+
245+
// TMA requires 16-byte alignment (128 bits) for memory transactions
246+
if (vect_bits % 128 != 0) {
247+
return false;
248+
}
249+
250+
return true;
251+
}
252+
253+
bool mayUseTmaOuter(
254+
const reduction_scheduler_utils::FusionRuntimeProperties& props) {
255+
auto dev_prop = at::cuda::getCurrentDeviceProperties();
256+
257+
if (dev_prop->major < 9) {
258+
return false;
259+
}
260+
261+
// Require outer reduction
262+
if (props.fastest_dim_reduction) {
263+
return false;
264+
}
265+
266+
int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
267+
uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
268+
269+
// Minimum TMA transfer size
270+
uint64_t min_tma_bytes = 16384;
271+
if (total_reduction_bytes < min_tma_bytes) {
272+
return false;
273+
}
274+
275+
// TMA tile may exceed Smem size for multiple tensors.
276+
if (props.n_tensor_inputs != 1) {
277+
return false;
278+
}
279+
280+
// Require that the innermost dim is contiguous.
281+
if (props.vectorize_factor <= 1) {
282+
return false;
283+
}
284+
285+
uint64_t vect_bits =
286+
props.vectorize_factor * props.max_dtype_size_bit_for_vectorization;
287+
288+
// TMA requires 16-byte alignment (128 bits) for memory transactions
289+
if (vect_bits % 128 != 0) {
290+
return false;
291+
}
292+
240293
return true;
241294
}
242295
} // namespace
@@ -250,14 +303,22 @@ std::unique_ptr<HeuristicParams> ReductionScheduler::computeHeuristics(
250303
auto props = reduction_scheduler_utils::getFusionRuntimeProperties(
251304
fusion, runtime_info, data_cache);
252305

253-
bool use_tma =
254-
mayUseTma(props) && isOptionEnabled(EnableOption::TmaReduction);
306+
bool tma_enabled = isOptionEnabled(EnableOption::TmaReduction);
255307

256308
std::unique_ptr<HeuristicParams> rparams = nullptr;
257-
if (use_tma) {
309+
310+
// Try outer TMA scheduler for outer reductions
311+
if (tma_enabled && mayUseTmaOuter(props)) {
312+
rparams = reduction::outer_tma::getReductionHeuristics(
313+
fusion, runtime_info, data_cache, props);
314+
}
315+
316+
// Try inner TMA scheduler for inner reductions
317+
if (rparams == nullptr && tma_enabled && mayUseTma(props)) {
258318
rparams = reduction::tma::getReductionHeuristics(
259319
fusion, runtime_info, data_cache, props);
260320
}
321+
261322
// Fallback to non-TMA scheduler if TMA is not applicable
262323
if (rparams == nullptr) {
263324
rparams = reduction::non_tma::getReductionHeuristics(
@@ -271,7 +332,11 @@ void ReductionScheduler::schedule(
271332
Fusion* fusion,
272333
const HeuristicParams* params) {
273334
FUSER_PERF_SCOPE("ReductionScheduler::schedule");
274-
if (auto* tma_params = dynamic_cast<const TmaInnerReductionParams*>(params)) {
335+
if (auto* outer_tma_params =
336+
dynamic_cast<const TmaOuterReductionParams*>(params)) {
337+
reduction::outer_tma::scheduleReduction(fusion, outer_tma_params);
338+
} else if (
339+
auto* tma_params = dynamic_cast<const TmaInnerReductionParams*>(params)) {
275340
reduction::tma::scheduleReduction(fusion, tma_params);
276341
} else {
277342
auto rparams = dynamic_cast<const ReductionParams*>(params);
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
9+
#include "scheduler/reduction_outer_tma.h"
10+
11+
#include <ATen/cuda/CUDAContext.h>
12+
13+
#include "scheduler/cache_policy_refiner.h"
14+
#include "scheduler/reduction_utils.h"
15+
#include "scheduler/runtime_info.h"
16+
#include "scheduler/tools/inlining.h"
17+
#include "scheduler/utils.h"
18+
19+
namespace nvfuser {
20+
namespace reduction {
21+
namespace outer_tma {
22+
23+
// Builds heuristic parameters for the outer-reduction TMA scheduler.
// The values are currently fixed constants (see TODO below); only the grid
// dimension adapts to the reduction size.
std::unique_ptr<TmaOuterReductionParams> getReductionHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache,
    const reduction_scheduler_utils::FusionRuntimeProperties& props) {
  FusionGuard fg(fusion);

  // TODO: These heuristics are stubbed out based on the manual test
  // (TmaOuterReductionManualTest::Basic). They need proper tuning.
  // NOTE(review): data_cache is currently unused here — reserved for future
  // heuristic caching.

  // 2D thread block: TIDx covers the iteration dim, TIDy the reduction dim.
  constexpr int64_t kBdimx = 32;
  constexpr int64_t kBdimy = 16;

  // TMA tile sizes; unroll factors are derived from these and the block dims.
  constexpr int64_t kTmaTileI = 128;
  constexpr int64_t kTmaTileR = 128;

  NVF_ERROR(kTmaTileI % kBdimx == 0);
  NVF_ERROR(kTmaTileR % kBdimy == 0);

  // Each thread along TIDx handles kTmaTileI / kBdimx iteration elements.
  const int64_t unroll_i = kTmaTileI / kBdimx;

  // Grid dimension for parallelizing the outer reduction across CTAs.
  // Modeled after the manual test: clamp lastPow2(outer_size / 256) to [1, 8].
  const int64_t reduction_size = props.total_reduction_numel;
  const int64_t grid_r = std::max<int64_t>(
      1,
      std::min<int64_t>(8, scheduler_utils::lastPow2(reduction_size / 256)));

  auto params = std::make_unique<TmaOuterReductionParams>();
  params->bdimx = kBdimx;
  params->bdimy = kBdimy;
  params->tma_tile_i = kTmaTileI;
  params->tma_tile_r = kTmaTileR;
  params->iter_unroll_factor = unroll_i;
  params->grdim = grid_r;

  params->tag = "Outer Reduction TMA heuristics";
  params->cparams.index_type = runtime_info.getIndexType();

  return params;
}
65+
66+
// Applies the outer-reduction TMA schedule to the fusion using the fixed
// parameters in rparams: inputs are staged into shared memory via TMA
// (CpAsyncBulkTensorTile), tiled [tma_tile_r x tma_tile_i], and reduced with
// a grid reduction of width grdim. Statement order below is significant:
// transforms are applied to the TMA TV first, propagated, then thread-level
// splits are applied to the reduction TV and propagated to non-TMA TVs.
void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
  FusionGuard fg(fusion);

  const int64_t bdimx = rparams->bdimx;
  const int64_t bdimy = rparams->bdimy;
  const int64_t tma_tile_i = rparams->tma_tile_i;
  const int64_t tma_tile_r = rparams->tma_tile_r;
  const int64_t iter_unroll_factor = rparams->iter_unroll_factor;
  const int64_t grdim = rparams->grdim;

  // Tile sizes must divide evenly into thread dims so the sub-splits in
  // Phase 5 produce exact factors.
  NVF_ERROR(tma_tile_i % bdimx == 0);
  NVF_ERROR(tma_tile_r % bdimy == 0);

  // Phase 1: Cache inputs into shared memory via TMA
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);

  scheduler_utils::clearMemorySpace(fusion);

  scheduler_utils::cacheAndForkOutputs(fusion, true);

  scheduler_utils::prepareForMemoryTypePromotion(fusion);

  // Convert each cached input's load into a TMA bulk-tensor copy and place
  // the cache in shared memory. Only LoadStoreOp-defined caches qualify.
  std::vector<TensorView*> tma_tvs;
  for (auto [tv, input_idx] : cached_inputs) {
    if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
      load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv->setMemoryType(MemoryType::Shared);
      tma_tvs.push_back(tv);
    }
  }
  NVF_ERROR(!tma_tvs.empty());

  // NOTE(review): only the first reduction TV is used as the scheduling
  // reference; multi-reduction fusions rely on propagation below.
  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
  NVF_ERROR(!reduction_tvs.empty());
  TensorView* reduction_tv = reduction_tvs.at(0);

  // canonicalizeReduction with schedule_3d=false merges into [I, R] form.
  // For outer reduction, we want [R, I], so we reorder after canonicalization.
  auto dim_analysis =
      scheduler_utils::canonicalizeReduction(fusion, reduction_tv, false);
  bool has_iter_axis = dim_analysis.first;
  bool has_red_axis = dim_analysis.second;
  NVF_ERROR(has_iter_axis && has_red_axis);

  // Reorder from [I, R] -> [R, I] for outer reduction pattern
  reduction_tv->reorder({{0, 1}, {1, 0}});

  // The reduction_tv already has a cacheBefore from cacheAndForkOutputs.
  // Use reduction_tv directly as our reduction reference for scheduling.
  TensorView* redu_tv = reduction_tv;

  // After canonicalization + reorder: [R, I]
  const int64_t outer_reduce_axis = 0;

  // Phase 2: Schedule TMA tensor with 2D TMA tiling
  // Apply transforms to the TMA smem TV.
  // We start from the TMA TV, which shares the same logical domain as the
  // reduction TV's producer.
  TensorView* tma_tv = tma_tvs[0];

  // [R, I] -> [R/tma_tile_r, tma_tile_r, I]
  tma_tv->split(outer_reduce_axis, tma_tile_r);

  // [R/tma_tile_r, tma_tile_r, I] -> [R/tma_tile_r, tma_tile_r, I/tma_tile_i,
  // tma_tile_i]
  tma_tv->split(2, tma_tile_i);

  // Split outer reduction for grid parallelization (BIDy)
  // [R/tma_tile_r, tma_tile_r, I/tma_tile_i, tma_tile_i]
  // -> [grdim, R', tma_tile_r, I/tma_tile_i, tma_tile_i]
  //       0    1       2            3            4
  // outer_split=false puts grdim as the outer factor of axis 0.
  tma_tv->split(0, grdim, false);

  // Phase 3: Propagate TMA tiling to all tensors
  TransformPropagator propagator(tma_tv);
  MaxLogicalDomainInfoSpanningTree(tma_tv).traverse(&propagator);

  // Phase 4: Parallelize TMA tensor
  // Bulk axes mark the dims transferred by a single TMA transaction.
  tma_tv->axis(0)->parallelize(ParallelType::BIDy);
  tma_tv->axis(1)->parallelize(ParallelType::Serial);
  tma_tv->axis(2)->parallelize(ParallelType::Bulk); // reduction tile
  tma_tv->axis(3)->parallelize(ParallelType::BIDx);
  tma_tv->axis(4)->parallelize(ParallelType::Bulk); // iteration tile

  // Parallelize remaining TMA tvs to match
  scheduler_utils::parallelizeAllLike(tma_tv, tma_tvs);

  // Phase 5: Sub-split TMA tiles into thread dims
  // Split tma_tile_i (axis 4) into [bdimx, iter_unroll]
  redu_tv->split(4, iter_unroll_factor);

  // Split tma_tile_r (axis 2) into [redu_unroll, bdimy]
  redu_tv->split(2, bdimy);
  // Now: [grdim, R', redu_unroll, bdimy, I/tma_tile_i, bdimx, iter_unroll]
  //        0     1       2          3         4          5         6

  // Phase 6: Parallelize reduction tensor
  redu_tv->axis(0)->parallelize(ParallelType::BIDy);
  redu_tv->axis(1)->parallelize(ParallelType::Serial);
  redu_tv->axis(2)->parallelize(ParallelType::Unroll); // redu_unroll
  redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
  redu_tv->axis(4)->parallelize(ParallelType::BIDx);
  redu_tv->axis(5)->parallelize(ParallelType::TIDx); // bdimx
  // Use Vectorize so it gets converted to Group for iterGroupedGridReduce
  redu_tv->axis(6)->parallelize(ParallelType::Vectorize); // iter_unroll

  // Phase 7: rFactor for grid reduction
  // rFactor reduction axes that are not thread-parallelized
  std::vector<int64_t> rfactor_axes;
  for (int64_t i = 0; i < redu_tv->nDims(); i++) {
    if (redu_tv->axis(i)->isReduction() && !redu_tv->axis(i)->isThread()) {
      rfactor_axes.push_back(i);
    }
  }

  // If all reduction axes are thread-parallel there is nothing to rFactor
  // and redu_tv itself serves as the reference.
  TensorView* reference_tv = redu_tv;
  if (!rfactor_axes.empty()) {
    reference_tv = redu_tv->rFactor(rfactor_axes);
  }

  // Phase 8: Propagate thread-level splits to non-TMA TVs
  // TMA TVs are excluded: their schedule (Bulk tiling) was fixed in Phase 4.
  std::vector<TensorView*> non_tma_tvs =
      ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
  TransformPropagator non_tma_propagator(reference_tv);
  SetSelector non_tma_selector({non_tma_tvs.begin(), non_tma_tvs.end()});
  MaxLogicalDomainInfoSpanningTree(reference_tv, &non_tma_selector)
      .traverse(&non_tma_propagator);

  // Phase 9: Propagate parallelization with iter-grouped reduction
  const bool use_iter_grouped_reduction = true;

  if (reference_tv != redu_tv) {
    reduction_scheduler_utils::propagateRFactor(
        reference_tv, redu_tv, reduction_tvs);
    // rFactor created a new TV; recompute the non-TMA set so it is included.
    non_tma_tvs =
        ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
  }

  // Collect output TVs for unroll_vectorizable_cached_tvs
  std::unordered_set<TensorView*> output_tvs;
  for (auto output : fusion->outputs()) {
    if (auto tv = dynamic_cast<TensorView*>(output)) {
      output_tvs.insert(tv);
    }
  }

  reduction_scheduler_utils::propagateParallelization(
      redu_tv,
      reference_tv,
      /*is_unroll_or_vectorization=*/true,
      use_iter_grouped_reduction,
      reduction_tvs,
      /*unroll_vectorizable_cached_tvs=*/output_tvs,
      /*selected_tvs=*/non_tma_tvs);

  // Phase 10: Inline and refine
  inlineMost();

  refineCachePolicy(fusion);
}
226+
227+
} // namespace outer_tma
228+
} // namespace reduction
229+
} // namespace nvfuser

0 commit comments

Comments
 (0)