|
| 1 | +// clang-format off |
| 2 | +/* |
| 3 | + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. |
| 4 | + * All rights reserved. |
| 5 | + * SPDX-License-Identifier: BSD-3-Clause |
| 6 | + */ |
| 7 | +// clang-format on |
| 8 | + |
| 9 | +#include "scheduler/reduction_outer_tma.h" |
| 10 | + |
| 11 | +#include <ATen/cuda/CUDAContext.h> |
| 12 | + |
| 13 | +#include "scheduler/cache_policy_refiner.h" |
| 14 | +#include "scheduler/reduction_utils.h" |
| 15 | +#include "scheduler/runtime_info.h" |
| 16 | +#include "scheduler/tools/inlining.h" |
| 17 | +#include "scheduler/utils.h" |
| 18 | + |
| 19 | +namespace nvfuser { |
| 20 | +namespace reduction { |
| 21 | +namespace outer_tma { |
| 22 | + |
| 23 | +std::unique_ptr<TmaOuterReductionParams> getReductionHeuristics( |
| 24 | + Fusion* fusion, |
| 25 | + SchedulerRuntimeInfo& runtime_info, |
| 26 | + HeuristicDataCache* data_cache, |
| 27 | + const reduction_scheduler_utils::FusionRuntimeProperties& props) { |
| 28 | + FusionGuard fg(fusion); |
| 29 | + |
| 30 | + // TODO: These heuristics are stubbed out based on the manual test |
| 31 | + // (TmaOuterReductionManualTest::Basic). They need proper tuning. |
| 32 | + |
| 33 | + // 2D thread block: TIDx covers iteration, TIDy covers reduction |
| 34 | + const int64_t bdimx = 32; |
| 35 | + const int64_t bdimy = 16; |
| 36 | + |
| 37 | + // TMA tile sizes. Unroll factors are derived from these and thread dims. |
| 38 | + const int64_t tma_tile_i = 128; |
| 39 | + const int64_t tma_tile_r = 128; |
| 40 | + |
| 41 | + NVF_ERROR(tma_tile_i % bdimx == 0); |
| 42 | + NVF_ERROR(tma_tile_r % bdimy == 0); |
| 43 | + |
| 44 | + const int64_t iter_unroll_factor = tma_tile_i / bdimx; |
| 45 | + |
| 46 | + // Grid dimension for parallelizing the outer reduction across CTAs. |
| 47 | + // Modeled after the manual test: clamp lastPow2(outer_size / 256) to [1, 8]. |
| 48 | + const int64_t outer_size = props.total_reduction_numel; |
| 49 | + int64_t grdim = std::max<int64_t>( |
| 50 | + 1, std::min<int64_t>(8, scheduler_utils::lastPow2(outer_size / 256))); |
| 51 | + |
| 52 | + auto params = std::make_unique<TmaOuterReductionParams>(); |
| 53 | + params->bdimx = bdimx; |
| 54 | + params->bdimy = bdimy; |
| 55 | + params->tma_tile_i = tma_tile_i; |
| 56 | + params->tma_tile_r = tma_tile_r; |
| 57 | + params->iter_unroll_factor = iter_unroll_factor; |
| 58 | + params->grdim = grdim; |
| 59 | + |
| 60 | + params->tag = "Outer Reduction TMA heuristics"; |
| 61 | + params->cparams.index_type = runtime_info.getIndexType(); |
| 62 | + |
| 63 | + return params; |
| 64 | +} |
| 65 | + |
// Applies the TMA outer-reduction schedule to `fusion` using the block/tile
// sizes carried in `rparams`. The work proceeds in ten ordered phases (see
// inline "Phase N" comments): cache inputs to shared memory via TMA, tile
// with 2D TMA boxes, propagate the tiling, parallelize, rFactor for the grid
// reduction, propagate parallelization to the remaining tensors, then inline.
// NOTE: phase order is load-bearing — axis indices used by later splits and
// parallelize() calls refer to the domain produced by earlier phases.
void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
  FusionGuard fg(fusion);

  const int64_t bdimx = rparams->bdimx;
  const int64_t bdimy = rparams->bdimy;
  const int64_t tma_tile_i = rparams->tma_tile_i;
  const int64_t tma_tile_r = rparams->tma_tile_r;
  const int64_t iter_unroll_factor = rparams->iter_unroll_factor;
  const int64_t grdim = rparams->grdim;

  // The thread-dim sub-splits below (Phase 5) require even divisibility.
  NVF_ERROR(tma_tile_i % bdimx == 0);
  NVF_ERROR(tma_tile_r % bdimy == 0);

  // Phase 1: Cache inputs into shared memory via TMA
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);

  scheduler_utils::clearMemorySpace(fusion);

  scheduler_utils::cacheAndForkOutputs(fusion, true);

  scheduler_utils::prepareForMemoryTypePromotion(fusion);

  // Convert each cached input's load into a TMA bulk-tensor copy and place it
  // in shared memory. `input_idx` from the pair is intentionally unused here.
  std::vector<TensorView*> tma_tvs;
  for (auto [tv, input_idx] : cached_inputs) {
    if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
      load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv->setMemoryType(MemoryType::Shared);
      tma_tvs.push_back(tv);
    }
  }
  NVF_ERROR(!tma_tvs.empty());

  // Only the first reduction TV is used as the scheduling reference; the
  // others are handled by propagation in Phases 8-9.
  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
  NVF_ERROR(!reduction_tvs.empty());
  TensorView* reduction_tv = reduction_tvs.at(0);

  // canonicalizeReduction with schedule_3d=false merges into [I, R] form.
  // For outer reduction, we want [R, I], so we reorder after canonicalization.
  auto dim_analysis =
      scheduler_utils::canonicalizeReduction(fusion, reduction_tv, false);
  bool has_iter_axis = dim_analysis.first;
  bool has_red_axis = dim_analysis.second;
  NVF_ERROR(has_iter_axis && has_red_axis);

  // Reorder from [I, R] -> [R, I] for outer reduction pattern
  reduction_tv->reorder({{0, 1}, {1, 0}});

  // The reduction_tv already has a cacheBefore from cacheAndForkOutputs.
  // Use reduction_tv directly as our reduction reference for scheduling.
  TensorView* redu_tv = reduction_tv;

  // After canonicalization + reorder: [R, I]
  const int64_t outer_reduce_axis = 0;

  // Phase 2: Schedule TMA tensor with 2D TMA tiling
  // Apply transforms to the TMA smem TV.
  // We start from the TMA TV, which shares the same logical domain as the
  // reduction TV's producer.
  // NOTE(review): tma_tvs[0] is assumed representative of all TMA inputs
  // (parallelizeAllLike below copies its parallelization) — confirm when
  // fusions with differently-shaped inputs are supported.
  TensorView* tma_tv = tma_tvs[0];

  // [R, I] -> [R/tma_tile_r, tma_tile_r, I]
  tma_tv->split(outer_reduce_axis, tma_tile_r);

  // [R/tma_tile_r, tma_tile_r, I] -> [R/tma_tile_r, tma_tile_r, I/tma_tile_i,
  // tma_tile_i]
  tma_tv->split(2, tma_tile_i);

  // Split outer reduction for grid parallelization (BIDy)
  // [R/tma_tile_r, tma_tile_r, I/tma_tile_i, tma_tile_i]
  // -> [grdim, R', tma_tile_r, I/tma_tile_i, tma_tile_i]
  //      0     1       2            3            4
  // (outer split: `false` puts grdim as the outer factor)
  tma_tv->split(0, grdim, false);

  // Phase 3: Propagate TMA tiling to all tensors
  TransformPropagator propagator(tma_tv);
  MaxLogicalDomainInfoSpanningTree(tma_tv).traverse(&propagator);

  // Phase 4: Parallelize TMA tensor
  // Bulk axes mark the dimensions covered by one TMA box.
  tma_tv->axis(0)->parallelize(ParallelType::BIDy);
  tma_tv->axis(1)->parallelize(ParallelType::Serial);
  tma_tv->axis(2)->parallelize(ParallelType::Bulk); // reduction tile
  tma_tv->axis(3)->parallelize(ParallelType::BIDx);
  tma_tv->axis(4)->parallelize(ParallelType::Bulk); // iteration tile

  // Parallelize remaining TMA tvs to match
  scheduler_utils::parallelizeAllLike(tma_tv, tma_tvs);

  // Phase 5: Sub-split TMA tiles into thread dims
  // Split tma_tile_i (axis 4) into [bdimx, iter_unroll]
  redu_tv->split(4, iter_unroll_factor);

  // Split tma_tile_r (axis 2) into [redu_unroll, bdimy]
  redu_tv->split(2, bdimy);
  // Now: [grdim, R', redu_unroll, bdimy, I/tma_tile_i, bdimx, iter_unroll]
  //        0    1       2          3         4           5        6

  // Phase 6: Parallelize reduction tensor
  redu_tv->axis(0)->parallelize(ParallelType::BIDy);
  redu_tv->axis(1)->parallelize(ParallelType::Serial);
  redu_tv->axis(2)->parallelize(ParallelType::Unroll); // redu_unroll
  redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
  redu_tv->axis(4)->parallelize(ParallelType::BIDx);
  redu_tv->axis(5)->parallelize(ParallelType::TIDx); // bdimx
  // Use Vectorize so it gets converted to Group for iterGroupedGridReduce
  redu_tv->axis(6)->parallelize(ParallelType::Vectorize); // iter_unroll

  // Phase 7: rFactor for grid reduction
  // rFactor reduction axes that are not thread-parallelized
  std::vector<int64_t> rfactor_axes;
  for (int64_t i = 0; i < redu_tv->nDims(); i++) {
    if (redu_tv->axis(i)->isReduction() && !redu_tv->axis(i)->isThread()) {
      rfactor_axes.push_back(i);
    }
  }

  // If every reduction axis is thread-parallel there is nothing to rFactor
  // and redu_tv itself remains the reference.
  TensorView* reference_tv = redu_tv;
  if (!rfactor_axes.empty()) {
    reference_tv = redu_tv->rFactor(rfactor_axes);
  }

  // Phase 8: Propagate thread-level splits to non-TMA TVs
  // The TMA TVs keep their coarser box-level schedule from Phase 3/4.
  std::vector<TensorView*> non_tma_tvs =
      ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
  TransformPropagator non_tma_propagator(reference_tv);
  SetSelector non_tma_selector({non_tma_tvs.begin(), non_tma_tvs.end()});
  MaxLogicalDomainInfoSpanningTree(reference_tv, &non_tma_selector)
      .traverse(&non_tma_propagator);

  // Phase 9: Propagate parallelization with iter-grouped reduction
  const bool use_iter_grouped_reduction = true;

  if (reference_tv != redu_tv) {
    reduction_scheduler_utils::propagateRFactor(
        reference_tv, redu_tv, reduction_tvs);
    // rFactor created a new TV; recompute the non-TMA set so it is included.
    non_tma_tvs =
        ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
  }

  // Collect output TVs for unroll_vectorizable_cached_tvs
  std::unordered_set<TensorView*> output_tvs;
  for (auto output : fusion->outputs()) {
    if (auto tv = dynamic_cast<TensorView*>(output)) {
      output_tvs.insert(tv);
    }
  }

  reduction_scheduler_utils::propagateParallelization(
      redu_tv,
      reference_tv,
      /*is_unroll_or_vectorization=*/true,
      use_iter_grouped_reduction,
      reduction_tvs,
      /*unroll_vectorizable_cached_tvs=*/output_tvs,
      /*selected_tvs=*/non_tma_tvs);

  // Phase 10: Inline and refine
  inlineMost();

  refineCachePolicy(fusion);
}
| 226 | + |
| 227 | +} // namespace outer_tma |
| 228 | +} // namespace reduction |
| 229 | +} // namespace nvfuser |
0 commit comments