diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp
index 831b2ae80b3..0d9164b1fac 100644
--- a/csrc/scheduler/reduction.cpp
+++ b/csrc/scheduler/reduction.cpp
@@ -11,6 +11,7 @@
 #include "debug.h"
 #include "instrumentation.h"
+#include "ir/utils.h"
 #include "scheduler/debug_utils.h"
 #include "scheduler/reduction_non_tma.h"
 #include "scheduler/reduction_outer_tma.h"
 
@@ -252,6 +253,11 @@ bool mayUseTma(
     return false;
   }
 
+  // Welford reductions not yet supported by TMA schedulers.
+  if (ir_utils::hasOpsOfType<WelfordOp>(fusion)) {
+    return false;
+  }
+
   return true;
 }
 
@@ -282,8 +288,20 @@ bool mayUseTmaOuter(
     return false;
   }
 
-  // TMA tile may exceed Smem size for multiple tensors.
-  if (props.n_tensor_inputs != 1) {
+  // Check that TMA tiles for all inputs fit in shared memory.
+  // The heuristic will shrink tiles to fit, but the minimum tile size is
+  // bdimy × bdimx (16 × 32). Reject if even minimum tiles don't fit.
+  // Reserve space for block reduction workspace and static smem overhead.
+  const int64_t min_tile_r = 16; // bdimy
+  const int64_t min_tile_i = 32; // bdimx
+  const int64_t threads_per_block = min_tile_i * min_tile_r;
+  int64_t smem_overhead =
+      alignSharedMemoryBytes(threads_per_block * dtype_bytes) +
+      kSharedMemoryAlignmentBytes;
+  int64_t min_smem_per_input = min_tile_r * min_tile_i * dtype_bytes;
+  int64_t min_total_smem =
+      min_smem_per_input * props.n_tensor_inputs + smem_overhead;
+  if (min_total_smem > (int64_t)dev_prop->sharedMemPerBlockOptin) {
     return false;
   }
 
@@ -300,6 +318,11 @@ bool mayUseTmaOuter(
     return false;
   }
 
+  // Welford reductions not yet supported by the TMA outer scheduler.
+  if (ir_utils::hasOpsOfType<WelfordOp>(fusion)) {
+    return false;
+  }
+
   return true;
 }
 } // namespace
diff --git a/csrc/scheduler/reduction_outer_tma.cpp b/csrc/scheduler/reduction_outer_tma.cpp
index 548d386bee2..070e4be48ee 100644
--- a/csrc/scheduler/reduction_outer_tma.cpp
+++ b/csrc/scheduler/reduction_outer_tma.cpp
@@ -10,6 +10,7 @@
 
 #include <ATen/cuda/CUDAContext.h>
 
+#include "ir/utils.h"
 #include "scheduler/cache_policy_refiner.h"
 #include "scheduler/reduction_utils.h"
 #include "scheduler/runtime_info.h"
@@ -34,9 +35,32 @@ std::unique_ptr<TmaOuterReductionParams> getReductionHeuristics(
   const int64_t bdimx = 32;
   const int64_t bdimy = 16;
 
-  // TMA tile sizes. Unroll factors are derived from these and thread dims.
-  const int64_t tma_tile_i = 128;
-  const int64_t tma_tile_r = 128;
+  // Compute TMA tile sizes based on available shared memory budget.
+  // Each input tensor gets a tma_tile_r * tma_tile_i tile in smem.
+  // Reserve space for block reduction workspace and static smem.
+  auto dev_prop = at::cuda::getCurrentDeviceProperties();
+  const int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  const int64_t threads_per_block = bdimx * bdimy;
+  const int64_t smem_overhead =
+      alignSharedMemoryBytes(threads_per_block * dtype_bytes) +
+      kSharedMemoryAlignmentBytes;
+  const int64_t smem_bytes =
+      (int64_t)dev_prop->sharedMemPerBlockOptin - smem_overhead;
+  const int64_t n_inputs = std::max(props.n_tensor_inputs, (int64_t)1);
+  const int64_t smem_per_input = smem_bytes / n_inputs;
+
+  // Start with default tile sizes and shrink if they exceed the per-input
+  // smem budget. Tile sizes must remain multiples of thread block dims.
+  int64_t tma_tile_i = 128;
+  int64_t tma_tile_r = 128;
+  while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
+         tma_tile_r > bdimy) {
+    tma_tile_r /= 2;
+  }
+  while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
+         tma_tile_i > bdimx) {
+    tma_tile_i /= 2;
+  }
 
   NVF_ERROR(tma_tile_i % bdimx == 0);
   NVF_ERROR(tma_tile_r % bdimy == 0);
@@ -110,6 +134,10 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
   // Reorder from [I, R] -> [R, I] for outer reduction pattern
   reduction_tv->reorder({{0, 1}, {1, 0}});
 
+  // Propagate the canonicalized [R, I] merges to all tensors.
+  TransformPropagator canon_propagator(reduction_tv);
+  MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&canon_propagator);
+
   // The reduction_tv already has a cacheBefore from cacheAndForkOutputs.
   // Use reduction_tv directly as our reduction reference for scheduling.
   TensorView* redu_tv = reduction_tv;
@@ -118,9 +146,8 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
   const int64_t outer_reduce_axis = 0;
 
   // Phase 2: Schedule TMA tensor with 2D TMA tiling
-  // Apply transforms to the TMA smem TV.
-  // We start from the TMA TV, which shares the same logical domain as the
-  // reduction TV's producer.
+  // Apply transforms to the TMA smem TV (now in [R, I] form after
+  // canonicalization propagation).
   TensorView* tma_tv = tma_tvs[0];
 
   // [R, I] -> [R/tma_tile_r, tma_tile_r, I]
@@ -166,8 +193,12 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
   redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
   redu_tv->axis(4)->parallelize(ParallelType::BIDx);
   redu_tv->axis(5)->parallelize(ParallelType::TIDx); // bdimx
-  // Use Vectorize so it gets converted to Group for iterGroupedGridReduce
-  redu_tv->axis(6)->parallelize(ParallelType::Vectorize); // iter_unroll
+
+  // Vectorize gets converted to Group for iterGroupedGridReduce.
+  // When iter_unroll_factor == 1, use Serial to avoid an invalid size-1
+  // Vectorize axis; this falls back to regular (non-grouped) grid reduction.
+  redu_tv->axis(6)->parallelize(
+      iter_unroll_factor > 1 ? ParallelType::Vectorize : ParallelType::Serial);
 
   // Phase 7: rFactor for grid reduction
   // rFactor reduction axes that are not thread-parallelized
@@ -180,7 +211,7 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
   TensorView* reference_tv = redu_tv;
 
   if (!rfactor_axes.empty()) {
-    reference_tv = redu_tv->rFactor(rfactor_axes);
+    reference_tv = ir_utils::rFactorHelper(redu_tv, rfactor_axes);
   }
 
   // Phase 8: Propagate thread-level splits to non-TMA TVs
@@ -191,8 +222,9 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
   MaxLogicalDomainInfoSpanningTree(reference_tv, &non_tma_selector)
       .traverse(&non_tma_propagator);
 
-  // Phase 9: Propagate parallelization with iter-grouped reduction
-  const bool use_iter_grouped_reduction = true;
+  // Phase 9: Propagate parallelization with iter-grouped reduction.
+  // Iter-grouping requires iter_unroll_factor > 1 (Vectorize -> Group).
+  const bool use_iter_grouped_reduction = iter_unroll_factor > 1;
 
   if (reference_tv != redu_tv) {
     reduction_scheduler_utils::propagateRFactor(
diff --git a/tests/cpp/test_reduction.cpp b/tests/cpp/test_reduction.cpp
index 9ce8e8b4d69..04e2cd56433 100644
--- a/tests/cpp/test_reduction.cpp
+++ b/tests/cpp/test_reduction.cpp
@@ -3099,7 +3099,8 @@ INSTANTIATE_TEST_SUITE_P(
 
 // Test outer reduction with auto-scheduled TMA
 using TmaOuterReductionTestParams =
-    std::tuple<int64_t, int64_t>; // <outer_size, iter_size>
+    std::tuple<int64_t, int64_t, int64_t, DataType>;
+// <outer_size, iter_size, n_inputs, dtype>
 
 class TmaOuterReductionTest
     : public NVFuserFixtureParamTest<TmaOuterReductionTestParams> {
@@ -3115,6 +3116,7 @@
   bool expectOuterTmaUsed(
       int64_t outer_size,
       int64_t iter_size,
+      int64_t n_inputs,
       int64_t dtype_bytes) {
     uint64_t total_reduction_bytes = outer_size * dtype_bytes;
     uint64_t min_tma_bytes = 16384;
@@ -3125,6 +3127,15 @@
     if ((iter_size * dtype_bytes) % 16 != 0) {
       return false;
     }
+    // Check minimum tile smem fits for all inputs, accounting for
+    // block reduction workspace + static smem overhead.
+    int64_t min_smem_per_input = 16 * 32 * dtype_bytes;
+    int64_t smem_overhead = ((512 * dtype_bytes + 127) & ~127) + 128;
+    auto* dev_prop = at::cuda::getCurrentDeviceProperties();
+    if (min_smem_per_input * n_inputs + smem_overhead >
+        (int64_t)dev_prop->sharedMemPerBlockOptin) {
+      return false;
+    }
     return true;
   }
 
@@ -3133,8 +3144,7 @@
 };
 
 TEST_P(TmaOuterReductionTest, Sum) {
-  auto [outer_size, iter_size] = GetParam();
-  auto dtype = DataType::Float;
+  auto [outer_size, iter_size, n_inputs, dtype] = GetParam();
   int64_t dtype_bytes = dataTypeSizeByte(dtype);
 
   // TMA requires the stride (iter_size * dtype_bytes) to be 16-byte aligned
@@ -3143,56 +3153,78 @@ TEST_P(TmaOuterReductionTest, Sum) {
     return;
   }
 
-  std::vector<int64_t> shape = {outer_size, iter_size};
-
   auto fusion_ptr = std::make_unique<Fusion>();
   FusionGuard fg(fusion_ptr.get());
   Fusion& fusion = *fusion_ptr;
 
-  auto tv0 = makeContigTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {0}); // reduce along axis 0 (outer reduction)
-  fusion.addOutput(tv1);
+  // Create n_inputs tensors and sum them element-wise before reducing
+  std::vector<TensorView*> inputs;
+  for (int64_t i = 0; i < n_inputs; i++) {
+    auto tv = makeContigTensor(2, dtype);
+    fusion.addInput(tv);
+    inputs.push_back(tv);
+  }
+  auto accum = inputs[0];
+  for (int64_t i = 1; i < n_inputs; i++) {
+    accum = add(accum, inputs[i]);
+  }
+  auto reduced = sum(accum, {0});
+  fusion.addOutput(reduced);
 
   auto unscheduled_fusion_copy = fusion;
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn(shape, options);
+  auto options =
+      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
+  std::vector<at::Tensor> aten_inputs;
+  for (int64_t i = 0; i < n_inputs; i++) {
+    aten_inputs.push_back(at::randn({outer_size, iter_size}, options));
+  }
 
   FusionExecutorCache executor_cache(std::move(fusion_ptr));
-  auto outputs = executor_cache.runFusionWithInputs({t0});
+  auto outputs =
+      executor_cache.runFusionWithInputs(aten_inputs);
 
-  if (expectOuterTmaUsed(outer_size, iter_size, dtype_bytes)) {
+  if (expectOuterTmaUsed(outer_size, iter_size, n_inputs, dtype_bytes)) {
     EXPECT_TRUE(tma_reduction_check::isOuterTmaParams(executor_cache))
         << "Expected outer TMA scheduler for outer_size=" << outer_size
-        << " iter_size=" << iter_size;
+        << " iter_size=" << iter_size << " n_inputs=" << n_inputs;
   }
 
-  testValidate(&unscheduled_fusion_copy, outputs, {t0}, __LINE__, __FILE__, "");
+  testValidate(
+      &unscheduled_fusion_copy, outputs, aten_inputs, __LINE__, __FILE__, "");
 }
 
 INSTANTIATE_TEST_SUITE_P(
     ,
     TmaOuterReductionTest,
-    testing::Combine(
-        testing::ValuesIn([] { // outer_size
-          std::vector<int64_t> vals;
-          for (int64_t v = 256; v <= 65536; v *= 4) {
-            vals.push_back(v);
-          }
-          return vals;
-        }()),
-        testing::ValuesIn([] { // iter_size
-          std::vector<int64_t> vals;
-          for (int64_t v = 256; v <= 65536; v *= 4) {
-            vals.push_back(v);
-          }
-          return vals;
-        }())),
+    testing::Values(
+        // Size sweep with single input, float
+        TmaOuterReductionTestParams{256, 256, 1, DataType::Float},
+        TmaOuterReductionTestParams{256, 4096, 1, DataType::Float},
+        TmaOuterReductionTestParams{256, 65536, 1, DataType::Float},
+        TmaOuterReductionTestParams{4096, 256, 1, DataType::Float},
+        TmaOuterReductionTestParams{4096, 4096, 1, DataType::Float},
+        TmaOuterReductionTestParams{4096, 65536, 1, DataType::Float},
+        TmaOuterReductionTestParams{65536, 256, 1, DataType::Float},
+        TmaOuterReductionTestParams{65536, 4096, 1, DataType::Float},
+        TmaOuterReductionTestParams{65536, 65536, 1, DataType::Float},
+        // Multi-input (exercises smem budget / tile shrinking)
+        TmaOuterReductionTestParams{4096, 1024, 2, DataType::Float},
+        TmaOuterReductionTestParams{4096, 1024, 3, DataType::Float},
+        TmaOuterReductionTestParams{4096, 1024, 5, DataType::Float},
+        TmaOuterReductionTestParams{4096, 1024, 2, DataType::Half},
+        TmaOuterReductionTestParams{4096, 1024, 3, DataType::Half},
+        TmaOuterReductionTestParams{4096, 1024, 5, DataType::Half}),
     ([](const testing::TestParamInfo<TmaOuterReductionTestParams>& info) {
-      auto [outer_size, iter_size] = info.param;
-      return "outer_" + std::to_string(outer_size) + "_iter_" +
+      auto [outer_size, iter_size, n_inputs, dtype] = info.param;
+      std::string name = "outer_" + std::to_string(outer_size) + "_iter_" +
          std::to_string(iter_size);
+      if (n_inputs > 1) {
+        name += "_" + std::to_string(n_inputs) + "inputs";
+      }
+      if (dtype != DataType::Float) {
+        name += "_half";
+      }
+      return name;
     }));
 
 } // namespace nvfuser