-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathreduction.cpp
More file actions
383 lines (323 loc) · 12 KB
/
reduction.cpp
File metadata and controls
383 lines (323 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include "scheduler/reduction.h"
#include <ATen/cuda/CUDAContext.h>
#include "debug.h"
#include "instrumentation.h"
#include "ir/utils.h"
#include "scheduler/debug_utils.h"
#include "scheduler/reduction_non_tma.h"
#include "scheduler/reduction_outer_tma.h"
#include "scheduler/reduction_tma.h"
#include "scheduler/reduction_utils.h"
#include "scheduler/registry_utils.h"
#include "scheduler/runtime_info.h"
namespace nvfuser {
//! Compile-time applicability check for the reduction scheduler.
//!
//! Walks through a fixed sequence of structural checks on the fusion and
//! returns false at the first one that fails. Most failing checks report a
//! reason through scheduler_debug_utils::canScheduleRejectReason. Returns true
//! only when every check passes.
bool ReductionScheduler::canScheduleCompileTime(Fusion* fusion) {
  FUSER_PERF_SCOPE("ReductionScheduler::canScheduleCompileTime");

  // Every tensor that is not of index type must have a bit width that is a
  // whole number of bytes.
  for (auto tv : fusion->allTvs()) {
    const bool is_sub_byte = tv->dtype() != DataType::Index &&
        dataTypeSizeBit(tv->dtype()) % 8 != 0;
    if (is_sub_byte) {
      scheduler_debug_utils::canScheduleRejectReason(
          schedulerType(), "Does not support sub-byte data types.");
      return false;
    }
  }

  if (scheduler_utils::isResharding(fusion)) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "Fusion is resharding.");
    return false;
  }

  // Without a reduction op there is nothing for this scheduler to do.
  if (!ir_utils::hasAnyReductionOps(fusion)) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "No reduction op to schedule");
    return false;
  }

  // At least one TensorView input is required.
  if (ir_utils::filterByType<TensorView>(fusion->inputs()).empty()) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "Scheduling not supported with no input");
    return false;
  }

  // Inputs of all select/gather-like ops must be fusion inputs. No reason is
  // reported from here; the helper is handed the scheduler type.
  if (registry_utils::rejectScheduleForMemoryPromotion(
          fusion, schedulerType())) {
    return false;
  }

  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
  if (reduction_tvs.empty()) {
    // Use pointwise logic
    return false;
  }

  // The checks below assume the reduction scheduler uses reduction_tvs[0] as
  // its reference tensor. If that ever changes, these checks must change too.
  // See https://github.com/NVIDIA/Fuser/issues/3811
  TensorView* reference_tv = reduction_tvs[0];

  // Output IDs must all be covered by the reference tensor.
  scheduler_tools::DomainMap domain_map(fusion);
  if (!domain_map.isValidReference(reference_tv, /*check_inputs=*/true)) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(),
        "Output contains ID that's not scheduled by reference tv.");
    return false;
  }

  if (registry_utils::hasNonUniqueBcast(fusion)) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(),
        "Broadcasting dimension might be broadcasting to multiple sizes.");
    return false;
  }

  // Reshape-specific checks; the ComputeAtMap is only built when the fusion
  // actually contains reshape ops.
  const bool has_reshape = !ir_utils::getReshapeOps(fusion).empty();
  if (has_reshape) {
    ComputeAtMap ca_map(fusion);
    if (registry_utils::requiresForwardViewReplay(fusion, ca_map)) {
      scheduler_debug_utils::canScheduleRejectReason(
          schedulerType(), "Fusion requires view being reversible.");
      return false;
    }
    if (registry_utils::reductionInterferingView(
            fusion, ca_map, reference_tv)) {
      scheduler_debug_utils::canScheduleRejectReason(
          schedulerType(), "View may interfere with reduction scheduling.");
      return false;
    }
  }

  // With more than one reduction op, the reduction axes must be consistent
  // across the fusion.
  auto reduction_ops = ir_utils::getAllTypesOfReductionOps(fusion);
  if (reduction_ops.size() > 1) {
    // Cheap pre-check: compare the number of non-broadcast IDs in the
    // (maybe-)root domain of each reduction tensor, so that the logical
    // domain map is not built in the easy rejection cases.
    auto count_non_broadcast_ids = [](TensorView* red_tv) {
      size_t non_broadcast = 0;
      for (auto id : red_tv->getMaybeRootDomain()) {
        non_broadcast += id->isBroadcast() ? 0 : 1;
      }
      return non_broadcast;
    };

    // reduction_tvs is known to be non-empty here; the first tensor sets the
    // expected count and every later tensor is compared against it.
    const size_t expected_axis_count =
        count_non_broadcast_ids(reduction_tvs[0]);
    for (size_t idx = 1; idx < reduction_tvs.size(); ++idx) {
      if (count_non_broadcast_ids(reduction_tvs[idx]) == expected_axis_count) {
        continue;
      }
      scheduler_debug_utils::canScheduleRejectReason(
          schedulerType(),
          "Inconsistent reduction root size: ",
          reduction_tvs[idx]->toString(),
          ", expected: ",
          expected_axis_count);
      return false;
    }

    // Full check: each adjacent pair of reduction tensors must have
    // equivalent patterns under the logical domain map.
    FusionGuard fg(fusion);
    ComputeAtLogicalDomainMap logical_map;
    logical_map.build(true);

    for (size_t idx = 1; idx < reduction_tvs.size(); ++idx) {
      TensorView* previous_tv = reduction_tvs[idx - 1];
      TensorView* current_tv = reduction_tvs[idx];
      if (registry_utils::checkPatternEquivalence(
              previous_tv, current_tv, logical_map)) {
        continue;
      }
      scheduler_debug_utils::canScheduleRejectReason(
          schedulerType(),
          "Un-mapped multi-reduction: ",
          previous_tv->toString(),
          " and ",
          current_tv->toString());
      return false;
    }
  }

  // Persistent kernels are not handled by this scheduler.
  auto persistent_info = scheduler_utils::persistentBuffers(fusion);
  if (!persistent_info.persistent_buffers.empty()) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(),
        "need persistent buffers that reduction scheduler doesn't handle");
    return false;
  }

  // Post-reduction topology: must be a supported post-reduction fusion and
  // must not contain a post-reduction broadcast. The second check only runs
  // when the first one passes.
  const bool unsupported_post_reduction =
      !registry_utils::SchedulerTopologyChecker::supportedPostReductionFusion(
          fusion, reduction_tvs) ||
      registry_utils::SchedulerTopologyChecker::hasPostReductionBCast(fusion);
  if (unsupported_post_reduction) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "has unsupported post reduction fusion");
    return false;
  }

  if (registry_utils::SchedulerTopologyChecker::
          hasGatherToBroadcastBeforeReduction(fusion, reduction_tvs)) {
    scheduler_debug_utils::canScheduleRejectReason(
        schedulerType(), "has unsupported gather-like ops before reduction");
    return false;
  }

  return true;
}
//! Runtime applicability check for the reduction scheduler.
//!
//! None of the arguments are inspected: this function always returns true.
bool ReductionScheduler::canScheduleRunTime(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache) {
  FUSER_PERF_SCOPE("ReductionScheduler::canScheduleRunTime");
  // There is no runtime-dependent rejection criterion in this function.
  constexpr bool kAlwaysSchedulable = true;
  return kAlwaysSchedulable;
}
namespace {
//! Decide whether the inner-reduction TMA heuristics may be attempted.
//!
//! Returns true only when every condition below holds; otherwise returns
//! false at the first failing condition.
//! \param fusion Fusion being scheduled.
//! \param props Runtime properties previously computed for the fusion.
bool mayUseTma(
    Fusion* fusion,
    const reduction_scheduler_utils::FusionRuntimeProperties& props) {
  // Needs a device whose major version is at least 9, and inputs whose inner
  // dimension is contiguous.
  const auto device = at::cuda::getCurrentDeviceProperties();
  if (device->major < 9 ||
      !scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
    return false;
  }

  // The reduction shape must be a 2D inner reduction: [I, R]
  const bool is_2d_inner_reduction = props.fastest_dim_reduction &&
      props.total_reduction_numel == props.inner_most_dimension_numel;
  if (!is_2d_inner_reduction) {
    return false;
  }

  // Minimum TMA transfer size, below which it seems much slower than non-TMA.
  constexpr uint64_t kMinTmaBytes = 16384;
  const int64_t bytes_per_element =
      props.max_dtype_size_bit_for_vectorization / 8;
  const uint64_t reduction_bytes =
      props.total_reduction_numel * bytes_per_element;
  if (reduction_bytes < kMinTmaBytes) {
    return false;
  }

  // The reduction dim has to fit into smem, until iteration over a large
  // reduction dim is added.
  const int64_t smem_capacity_in_elements =
      (static_cast<int64_t>(device->sharedMemPerBlockOptin) * 8) /
      props.max_dtype_size_bit_for_vectorization;
  if (props.inner_most_dimension_numel > smem_capacity_in_elements) {
    return false;
  }

  // The smem check above assumes exactly one input tensor, and the innermost
  // dim must be contiguous (vectorize factor above 1).
  if (props.n_tensor_inputs != 1 || props.vectorize_factor <= 1) {
    return false;
  }

  // TMA requires 16-byte alignment (128 bits) for memory transactions
  const uint64_t vectorized_bits =
      props.vectorize_factor * props.max_dtype_size_bit_for_vectorization;
  if (vectorized_bits % 128 != 0) {
    return false;
  }

  // Welford reductions are not yet supported by the TMA schedulers.
  return !ir_utils::hasOpsOfType<WelfordOp>(fusion);
}
//! Decide whether the outer-reduction TMA heuristics may be attempted.
//!
//! Returns true only when every condition below holds; otherwise returns
//! false at the first failing condition.
//! \param fusion Fusion being scheduled.
//! \param props Runtime properties previously computed for the fusion.
bool mayUseTmaOuter(
    Fusion* fusion,
    const reduction_scheduler_utils::FusionRuntimeProperties& props) {
  // Needs a device whose major version is at least 9, and inputs whose inner
  // dimension is contiguous.
  const auto device = at::cuda::getCurrentDeviceProperties();
  if (device->major < 9 ||
      !scheduler_utils::inputsHaveContiguousInnerDim(fusion)) {
    return false;
  }

  // Only outer reductions qualify.
  if (props.fastest_dim_reduction) {
    return false;
  }

  // Minimum TMA transfer size
  constexpr uint64_t kMinTmaBytes = 16384;
  const int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
  const uint64_t reduction_bytes = props.total_reduction_numel * dtype_bytes;
  if (reduction_bytes < kMinTmaBytes) {
    return false;
  }

  // TMA tiles for all inputs have to fit in shared memory. The heuristic will
  // shrink tiles to fit, but the minimum tile size is bdimy x bdimx (16 x 32),
  // so reject when even minimum tiles do not fit. Space is reserved for the
  // block reduction workspace and static smem overhead.
  constexpr int64_t kMinTileR = 16; // bdimy
  constexpr int64_t kMinTileI = 32; // bdimx
  constexpr int64_t kThreadsPerBlock = kMinTileI * kMinTileR;
  const int64_t reserved_smem_bytes =
      alignSharedMemoryBytes(kThreadsPerBlock * dtype_bytes) +
      kSharedMemoryAlignmentBytes;
  const int64_t min_tile_bytes_per_input = kMinTileR * kMinTileI * dtype_bytes;
  const int64_t required_smem_bytes =
      min_tile_bytes_per_input * props.n_tensor_inputs + reserved_smem_bytes;
  if (required_smem_bytes >
      static_cast<int64_t>(device->sharedMemPerBlockOptin)) {
    return false;
  }

  // The innermost dim must be contiguous (vectorize factor above 1).
  if (props.vectorize_factor <= 1) {
    return false;
  }

  // TMA requires 16-byte alignment (128 bits) for memory transactions
  const uint64_t vectorized_bits =
      props.vectorize_factor * props.max_dtype_size_bit_for_vectorization;
  if (vectorized_bits % 128 != 0) {
    return false;
  }

  // Welford reductions are not yet supported by the TMA outer scheduler.
  return !ir_utils::hasOpsOfType<WelfordOp>(fusion);
}
} // namespace
//! Compute heuristic parameters for scheduling the reduction.
//!
//! When EnableOption::TmaReduction is enabled, the outer TMA heuristics are
//! tried first (if mayUseTmaOuter allows), then the inner TMA heuristics (if
//! mayUseTma allows and nothing was produced yet). Whenever no TMA path
//! produced parameters, the non-TMA heuristics are used.
//! \return Non-null heuristic parameters; a null result trips NVF_ERROR.
std::unique_ptr<HeuristicParams> ReductionScheduler::computeHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicDataCache* data_cache) {
  FUSER_PERF_SCOPE("ReductionScheduler::computeHeuristics");

  auto props = reduction_scheduler_utils::getFusionRuntimeProperties(
      fusion, runtime_info, data_cache);

  std::unique_ptr<HeuristicParams> rparams;

  if (isOptionEnabled(EnableOption::TmaReduction)) {
    // Outer TMA scheduler, for outer reductions.
    if (mayUseTmaOuter(fusion, props)) {
      rparams = reduction::outer_tma::getReductionHeuristics(
          fusion, runtime_info, data_cache, props);
    }
    // Inner TMA scheduler, only when the outer attempt left nothing.
    if (rparams == nullptr && mayUseTma(fusion, props)) {
      rparams = reduction::tma::getReductionHeuristics(
          fusion, runtime_info, data_cache, props);
    }
  }

  // Non-TMA scheduler as the fallback.
  if (rparams == nullptr) {
    rparams = reduction::non_tma::getReductionHeuristics(
        fusion, runtime_info, data_cache, props);
  }

  NVF_ERROR(rparams != nullptr);
  return rparams;
}
//! Apply the schedule matching the dynamic type of the heuristic parameters.
//!
//! Dispatch order: TmaOuterReductionParams, then TmaInnerReductionParams,
//! then ReductionParams. If params is none of these, NVF_ERROR is raised.
//! \param fusion Fusion to schedule.
//! \param params Heuristic parameters selecting the scheduling path.
void ReductionScheduler::schedule(
    Fusion* fusion,
    const HeuristicParams* params) {
  FUSER_PERF_SCOPE("ReductionScheduler::schedule");

  if (const auto* outer_params =
          dynamic_cast<const TmaOuterReductionParams*>(params)) {
    reduction::outer_tma::scheduleReduction(fusion, outer_params);
    return;
  }

  if (const auto* inner_params =
          dynamic_cast<const TmaInnerReductionParams*>(params)) {
    reduction::tma::scheduleReduction(fusion, inner_params);
    return;
  }

  // Neither TMA parameter type matched, so these must be non-TMA parameters.
  auto rparams = dynamic_cast<const ReductionParams*>(params);
  NVF_ERROR(
      rparams != nullptr,
      "Incorrect parameters sent to ReductionScheduler::schedule",
      params);
  reduction::non_tma::scheduleReduction(fusion, rparams);
}
} // namespace nvfuser