-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathreduction_outer_tma.cpp
More file actions
251 lines (201 loc) · 8.91 KB
/
reduction_outer_tma.cpp
File metadata and controls
251 lines (201 loc) · 8.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include "scheduler/reduction_outer_tma.h"
#include <ATen/cuda/CUDAContext.h>
#include "ir/utils.h"
#include "scheduler/cache_policy_refiner.h"
#include "scheduler/reduction_utils.h"
#include "scheduler/runtime_info.h"
#include "scheduler/tools/inlining.h"
#include "scheduler/utils.h"
namespace nvfuser {
namespace reduction {
namespace outer_tma {
// Builds stub heuristics for the TMA-based outer reduction scheduler.
//
// Thread-block shape is fixed at 32x16 (TIDx over the iteration dimension,
// TIDy over the reduction dimension). TMA tile sizes start at 128x128 and
// are halved — reduction side first, then iteration side — until the
// per-input shared-memory footprint fits the device budget. The grid
// dimension for the cross-CTA reduction is lastPow2(outer / 256) clamped
// to [1, 8], mirroring TmaOuterReductionManualTest::Basic.
//
// TODO: these values are placeholders and still need proper tuning.
std::unique_ptr<TmaOuterReductionParams> getReductionHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache,
    const reduction_scheduler_utils::FusionRuntimeProperties& props) {
  FusionGuard fg(fusion);

  // Fixed 2D thread block: x covers iteration, y covers reduction.
  const int64_t threads_x = 32;
  const int64_t threads_y = 16;

  // Shared-memory budget: each input tensor holds one tile_r x tile_i tile
  // in smem, so divide the opt-in capacity evenly across the inputs.
  const auto* device_properties = at::cuda::getCurrentDeviceProperties();
  const int64_t total_smem_bytes =
      (int64_t)device_properties->sharedMemPerBlockOptin;
  const int64_t bytes_per_elem = props.max_dtype_size_bit_for_vectorization / 8;
  const int64_t input_count = std::max(props.n_tensor_inputs, (int64_t)1);
  const int64_t budget_per_input = total_smem_bytes / input_count;

  // Shrink the default 128x128 tile until it fits, keeping each extent a
  // multiple of the corresponding thread-block dimension (both start as
  // powers of two, so halving preserves divisibility).
  int64_t tile_i = 128;
  int64_t tile_r = 128;
  const auto tile_footprint = [&]() {
    return tile_r * tile_i * bytes_per_elem;
  };
  while (tile_footprint() > budget_per_input && tile_r > threads_y) {
    tile_r /= 2;
  }
  while (tile_footprint() > budget_per_input && tile_i > threads_x) {
    tile_i /= 2;
  }
  NVF_ERROR(tile_i % threads_x == 0);
  NVF_ERROR(tile_r % threads_y == 0);

  // Each thread handles tile_i / threads_x iteration elements per tile.
  const int64_t unroll_iter = tile_i / threads_x;

  // Grid parallelization of the outer reduction across CTAs, clamped to
  // [1, 8] as in the manual test.
  const int64_t outer_extent = props.total_reduction_numel;
  const int64_t grid_dim = std::clamp<int64_t>(
      scheduler_utils::lastPow2(outer_extent / 256), 1, 8);

  auto params = std::make_unique<TmaOuterReductionParams>();
  params->bdimx = threads_x;
  params->bdimy = threads_y;
  params->tma_tile_i = tile_i;
  params->tma_tile_r = tile_r;
  params->iter_unroll_factor = unroll_iter;
  params->grdim = grid_dim;
  params->tag = "Outer Reduction TMA heuristics";
  params->cparams.index_type = runtime_info.getIndexType();
  return params;
}
// Applies the TMA outer-reduction schedule described by `rparams` to
// `fusion`. Inputs are staged through shared memory with TMA
// (CpAsyncBulkTensorTile) loads, the reduction is tiled as
// [grdim, R', tile_r, I/tile_i, tile_i], and the serial reduction factors
// are rFactor'ed for a grid reduction. The phases below are order-sensitive:
// each transform/propagation step assumes the axis layout produced by the
// previous one.
void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
  FusionGuard fg(fusion);
  const int64_t bdimx = rparams->bdimx;
  const int64_t bdimy = rparams->bdimy;
  const int64_t tma_tile_i = rparams->tma_tile_i;
  const int64_t tma_tile_r = rparams->tma_tile_r;
  const int64_t iter_unroll_factor = rparams->iter_unroll_factor;
  const int64_t grdim = rparams->grdim;
  // Tile extents must be divisible by the thread-block dims: the tiles are
  // sub-split by bdimx/bdimy in Phase 5.
  NVF_ERROR(tma_tile_i % bdimx == 0);
  NVF_ERROR(tma_tile_r % bdimy == 0);
  // Phase 1: Cache inputs into shared memory via TMA. Every cached input
  // whose producing op is a LoadStoreOp is converted into a TMA tile load
  // and moved to shared memory; those TVs are collected in tma_tvs.
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);
  scheduler_utils::clearMemorySpace(fusion);
  scheduler_utils::cacheAndForkOutputs(fusion, true);
  scheduler_utils::prepareForMemoryTypePromotion(fusion);
  std::vector<TensorView*> tma_tvs;
  for (auto [tv, input_idx] : cached_inputs) {
    if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
      load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv->setMemoryType(MemoryType::Shared);
      tma_tvs.push_back(tv);
    }
  }
  NVF_ERROR(!tma_tvs.empty());
  // Only the first reduction TV is used as the scheduling reference;
  // the rest are handled via propagation later.
  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
  NVF_ERROR(!reduction_tvs.empty());
  TensorView* reduction_tv = reduction_tvs.at(0);
  // canonicalizeReduction with schedule_3d=false merges into [I, R] form.
  // For outer reduction, we want [R, I], so we reorder after canonicalization.
  auto dim_analysis =
      scheduler_utils::canonicalizeReduction(fusion, reduction_tv, false);
  bool has_iter_axis = dim_analysis.first;
  bool has_red_axis = dim_analysis.second;
  // This scheduler only supports fusions with both an iteration and a
  // reduction axis.
  NVF_ERROR(has_iter_axis && has_red_axis);
  // Reorder from [I, R] -> [R, I] for outer reduction pattern
  reduction_tv->reorder({{0, 1}, {1, 0}});
  // Propagate the canonicalized [R, I] merges to all tensors.
  TransformPropagator canon_propagator(reduction_tv);
  MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&canon_propagator);
  // The reduction_tv already has a cacheBefore from cacheAndForkOutputs.
  // Use reduction_tv directly as our reduction reference for scheduling.
  TensorView* redu_tv = reduction_tv;
  // After canonicalization + reorder: [R, I]
  const int64_t outer_reduce_axis = 0;
  // Phase 2: Schedule TMA tensor with 2D TMA tiling.
  // Apply transforms to the first TMA smem TV (now in [R, I] form after
  // canonicalization propagation); the others pick them up via
  // parallelizeAllLike below.
  TensorView* tma_tv = tma_tvs[0];
  // [R, I] -> [R/tma_tile_r, tma_tile_r, I]
  tma_tv->split(outer_reduce_axis, tma_tile_r);
  // [R/tma_tile_r, tma_tile_r, I] -> [R/tma_tile_r, tma_tile_r, I/tma_tile_i,
  // tma_tile_i]
  tma_tv->split(2, tma_tile_i);
  // Split outer reduction for grid parallelization (BIDy). The outer=false
  // flag makes grdim the OUTER factor of the split:
  // [R/tma_tile_r, tma_tile_r, I/tma_tile_i, tma_tile_i]
  // -> [grdim, R', tma_tile_r, I/tma_tile_i, tma_tile_i]
  //      0     1       2            3            4
  tma_tv->split(0, grdim, false);
  // Phase 3: Propagate TMA tiling to all tensors
  TransformPropagator propagator(tma_tv);
  MaxLogicalDomainInfoSpanningTree(tma_tv).traverse(&propagator);
  // Phase 4: Parallelize TMA tensor. The tile axes are ParallelType::Bulk,
  // i.e. moved as a whole by the TMA engine rather than by threads.
  tma_tv->axis(0)->parallelize(ParallelType::BIDy);
  tma_tv->axis(1)->parallelize(ParallelType::Serial);
  tma_tv->axis(2)->parallelize(ParallelType::Bulk); // reduction tile
  tma_tv->axis(3)->parallelize(ParallelType::BIDx);
  tma_tv->axis(4)->parallelize(ParallelType::Bulk); // iteration tile
  // Parallelize remaining TMA tvs to match
  scheduler_utils::parallelizeAllLike(tma_tv, tma_tvs);
  // Phase 5: Sub-split TMA tiles into thread dims (reduction TV only;
  // the TMA TVs keep their Bulk tile axes).
  // Split tma_tile_i (axis 4) into [bdimx, iter_unroll]
  redu_tv->split(4, iter_unroll_factor);
  // Split tma_tile_r (axis 2) into [redu_unroll, bdimy]
  redu_tv->split(2, bdimy);
  // Now: [grdim, R', redu_unroll, bdimy, I/tma_tile_i, bdimx, iter_unroll]
  //        0     1       2          3         4          5        6
  // Phase 6: Parallelize reduction tensor
  redu_tv->axis(0)->parallelize(ParallelType::BIDy);
  redu_tv->axis(1)->parallelize(ParallelType::Serial);
  redu_tv->axis(2)->parallelize(ParallelType::Unroll); // redu_unroll
  redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
  redu_tv->axis(4)->parallelize(ParallelType::BIDx);
  redu_tv->axis(5)->parallelize(ParallelType::TIDx); // bdimx
  // Use Vectorize so it gets converted to Group for iterGroupedGridReduce
  redu_tv->axis(6)->parallelize(ParallelType::Vectorize); // iter_unroll
  // Phase 7: rFactor for grid reduction.
  // rFactor the reduction axes that are NOT thread-parallelized (the serial
  // and unrolled factors), leaving the block/grid-parallel axes for the
  // final grid reduction.
  std::vector<int64_t> rfactor_axes;
  for (int64_t i = 0; i < redu_tv->nDims(); i++) {
    if (redu_tv->axis(i)->isReduction() && !redu_tv->axis(i)->isThread()) {
      rfactor_axes.push_back(i);
    }
  }
  // If rFactor applies, the rFactor product becomes the reference for the
  // remaining propagation; otherwise redu_tv itself is the reference.
  TensorView* reference_tv = redu_tv;
  if (!rfactor_axes.empty()) {
    reference_tv = ir_utils::rFactorHelper(redu_tv, rfactor_axes);
  }
  // Phase 8: Propagate thread-level splits to non-TMA TVs. The TMA TVs are
  // excluded so their Bulk-tiled schedule from Phase 4 is preserved.
  std::vector<TensorView*> non_tma_tvs =
      ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
  TransformPropagator non_tma_propagator(reference_tv);
  SetSelector non_tma_selector({non_tma_tvs.begin(), non_tma_tvs.end()});
  MaxLogicalDomainInfoSpanningTree(reference_tv, &non_tma_selector)
      .traverse(&non_tma_propagator);
  // Phase 9: Propagate parallelization with iter-grouped reduction
  const bool use_iter_grouped_reduction = true;
  if (reference_tv != redu_tv) {
    reduction_scheduler_utils::propagateRFactor(
        reference_tv, redu_tv, reduction_tvs);
    // rFactor created new TVs; recompute the non-TMA set so they are
    // included in the parallelization propagation below.
    non_tma_tvs =
        ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
  }
  // Collect output TVs for unroll_vectorizable_cached_tvs
  std::unordered_set<TensorView*> output_tvs;
  for (auto output : fusion->outputs()) {
    if (auto tv = dynamic_cast<TensorView*>(output)) {
      output_tvs.insert(tv);
    }
  }
  reduction_scheduler_utils::propagateParallelization(
      redu_tv,
      reference_tv,
      /*is_unroll_or_vectorization=*/true,
      use_iter_grouped_reduction,
      reduction_tvs,
      /*unroll_vectorizable_cached_tvs=*/output_tvs,
      /*selected_tvs=*/non_tma_tvs);
  // Phase 10: Inline and refine
  inlineMost();
  refineCachePolicy(fusion);
}
} // namespace outer_tma
} // namespace reduction
} // namespace nvfuser