Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "debug.h"
#include "instrumentation.h"
#include "ir/utils.h"
#include "scheduler/debug_utils.h"
#include "scheduler/reduction_non_tma.h"
#include "scheduler/reduction_outer_tma.h"
Expand Down Expand Up @@ -252,6 +253,11 @@ bool mayUseTma(
return false;
}

// Welford reductions not yet supported by TMA schedulers.
if (ir_utils::hasOpsOfType<WelfordOp>(fusion)) {
return false;
}

return true;
}

Expand Down Expand Up @@ -282,8 +288,20 @@ bool mayUseTmaOuter(
return false;
}

// TMA tile may exceed Smem size for multiple tensors.
if (props.n_tensor_inputs != 1) {
// Check that TMA tiles for all inputs fit in shared memory.
// The heuristic will shrink tiles to fit, but the minimum tile size is
// bdimy × bdimx (16 × 32). Reject if even minimum tiles don't fit.
// Reserve space for block reduction workspace and static smem overhead.
const int64_t min_tile_r = 16; // bdimy
const int64_t min_tile_i = 32; // bdimx
const int64_t threads_per_block = min_tile_i * min_tile_r;
int64_t smem_overhead =
alignSharedMemoryBytes(threads_per_block * dtype_bytes) +
kSharedMemoryAlignmentBytes;
int64_t min_smem_per_input = min_tile_r * min_tile_i * dtype_bytes;
int64_t min_total_smem =
min_smem_per_input * props.n_tensor_inputs + smem_overhead;
if (min_total_smem > (int64_t)dev_prop->sharedMemPerBlockOptin) {
return false;
}

Expand All @@ -300,6 +318,11 @@ bool mayUseTmaOuter(
return false;
}

// Welford reductions not yet supported by the TMA outer scheduler.
if (ir_utils::hasOpsOfType<WelfordOp>(fusion)) {
return false;
}

return true;
}
} // namespace
Expand Down
54 changes: 43 additions & 11 deletions csrc/scheduler/reduction_outer_tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <ATen/cuda/CUDAContext.h>

#include "ir/utils.h"
#include "scheduler/cache_policy_refiner.h"
#include "scheduler/reduction_utils.h"
#include "scheduler/runtime_info.h"
Expand All @@ -34,9 +35,32 @@ std::unique_ptr<TmaOuterReductionParams> getReductionHeuristics(
const int64_t bdimx = 32;
const int64_t bdimy = 16;

// TMA tile sizes. Unroll factors are derived from these and thread dims.
const int64_t tma_tile_i = 128;
const int64_t tma_tile_r = 128;
// Compute TMA tile sizes based on available shared memory budget.
// Each input tensor gets a tma_tile_r * tma_tile_i tile in smem.
// Reserve space for block reduction workspace and static smem.
auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
const int64_t threads_per_block = bdimx * bdimy;
const int64_t smem_overhead =
alignSharedMemoryBytes(threads_per_block * dtype_bytes) +
kSharedMemoryAlignmentBytes;
const int64_t smem_bytes =
(int64_t)dev_prop->sharedMemPerBlockOptin - smem_overhead;
const int64_t n_inputs = std::max(props.n_tensor_inputs, (int64_t)1);
const int64_t smem_per_input = smem_bytes / n_inputs;

// Start with default tile sizes and shrink if they exceed the per-input
// smem budget. Tile sizes must remain multiples of thread block dims.
int64_t tma_tile_i = 128;
int64_t tma_tile_r = 128;
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
tma_tile_r > bdimy) {
tma_tile_r /= 2;
}
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
tma_tile_i > bdimx) {
tma_tile_i /= 2;
}

NVF_ERROR(tma_tile_i % bdimx == 0);
NVF_ERROR(tma_tile_r % bdimy == 0);
Expand Down Expand Up @@ -110,6 +134,10 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
// Reorder from [I, R] -> [R, I] for outer reduction pattern
reduction_tv->reorder({{0, 1}, {1, 0}});

// Propagate the canonicalized [R, I] merges to all tensors.
TransformPropagator canon_propagator(reduction_tv);
MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&canon_propagator);

// The reduction_tv already has a cacheBefore from cacheAndForkOutputs.
// Use reduction_tv directly as our reduction reference for scheduling.
TensorView* redu_tv = reduction_tv;
Expand All @@ -118,9 +146,8 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
const int64_t outer_reduce_axis = 0;

// Phase 2: Schedule TMA tensor with 2D TMA tiling
// Apply transforms to the TMA smem TV.
// We start from the TMA TV, which shares the same logical domain as the
// reduction TV's producer.
// Apply transforms to the TMA smem TV (now in [R, I] form after
// canonicalization propagation).
TensorView* tma_tv = tma_tvs[0];

// [R, I] -> [R/tma_tile_r, tma_tile_r, I]
Expand Down Expand Up @@ -166,8 +193,12 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
redu_tv->axis(4)->parallelize(ParallelType::BIDx);
redu_tv->axis(5)->parallelize(ParallelType::TIDx); // bdimx
// Use Vectorize so it gets converted to Group for iterGroupedGridReduce
redu_tv->axis(6)->parallelize(ParallelType::Vectorize); // iter_unroll

// Vectorize gets converted to Group for iterGroupedGridReduce.
// When iter_unroll_factor == 1, use Serial to avoid an invalid size-1
// Vectorize axis; this falls back to regular (non-grouped) grid reduction.
redu_tv->axis(6)->parallelize(
iter_unroll_factor > 1 ? ParallelType::Vectorize : ParallelType::Serial);

// Phase 7: rFactor for grid reduction
// rFactor reduction axes that are not thread-parallelized
Expand All @@ -180,7 +211,7 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {

TensorView* reference_tv = redu_tv;
if (!rfactor_axes.empty()) {
reference_tv = redu_tv->rFactor(rfactor_axes);
reference_tv = ir_utils::rFactorHelper(redu_tv, rfactor_axes);
}

// Phase 8: Propagate thread-level splits to non-TMA TVs
Expand All @@ -191,8 +222,9 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
MaxLogicalDomainInfoSpanningTree(reference_tv, &non_tma_selector)
.traverse(&non_tma_propagator);

// Phase 9: Propagate parallelization with iter-grouped reduction
const bool use_iter_grouped_reduction = true;
// Phase 9: Propagate parallelization with iter-grouped reduction.
// Iter-grouping requires iter_unroll_factor > 1 (Vectorize -> Group).
const bool use_iter_grouped_reduction = iter_unroll_factor > 1;

if (reference_tv != redu_tv) {
reduction_scheduler_utils::propagateRFactor(
Expand Down
96 changes: 64 additions & 32 deletions tests/cpp/test_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3099,7 +3099,8 @@ INSTANTIATE_TEST_SUITE_P(

// Test outer reduction with auto-scheduled TMA
using TmaOuterReductionTestParams =
std::tuple<int64_t, int64_t>; // <outer_size, iter_size>
std::tuple<int64_t, int64_t, int64_t, DataType>;
// <outer_size, iter_size, n_inputs, dtype>

class TmaOuterReductionTest
: public NVFuserFixtureParamTest<TmaOuterReductionTestParams> {
Expand All @@ -3115,6 +3116,7 @@ class TmaOuterReductionTest
bool expectOuterTmaUsed(
int64_t outer_size,
int64_t iter_size,
int64_t n_inputs,
int64_t dtype_bytes) {
uint64_t total_reduction_bytes = outer_size * dtype_bytes;
uint64_t min_tma_bytes = 16384;
Expand All @@ -3125,6 +3127,15 @@ class TmaOuterReductionTest
if ((iter_size * dtype_bytes) % 16 != 0) {
return false;
}
// Check minimum tile smem fits for all inputs, accounting for
// block reduction workspace + static smem overhead.
int64_t min_smem_per_input = 16 * 32 * dtype_bytes;
int64_t smem_overhead = ((512 * dtype_bytes + 127) & ~127) + 128;
auto* dev_prop = at::cuda::getCurrentDeviceProperties();
if (min_smem_per_input * n_inputs + smem_overhead >
(int64_t)dev_prop->sharedMemPerBlockOptin) {
return false;
}
return true;
}

Expand All @@ -3133,8 +3144,7 @@ class TmaOuterReductionTest
};

TEST_P(TmaOuterReductionTest, Sum) {
auto [outer_size, iter_size] = GetParam();
auto dtype = DataType::Float;
auto [outer_size, iter_size, n_inputs, dtype] = GetParam();
int64_t dtype_bytes = dataTypeSizeByte(dtype);

// TMA requires the stride (iter_size * dtype_bytes) to be 16-byte aligned
Expand All @@ -3143,56 +3153,78 @@ TEST_P(TmaOuterReductionTest, Sum) {
return;
}

std::vector<int64_t> shape = {outer_size, iter_size};

auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;

auto tv0 = makeContigTensor(2);
fusion.addInput(tv0);
auto tv1 = sum(tv0, {0}); // reduce along axis 0 (outer reduction)
fusion.addOutput(tv1);
// Create n_inputs tensors and sum them element-wise before reducing
std::vector<TensorView*> inputs;
for (int64_t i = 0; i < n_inputs; i++) {
auto tv = makeContigTensor(2, dtype);
fusion.addInput(tv);
inputs.push_back(tv);
}
auto accum = inputs[0];
for (int64_t i = 1; i < n_inputs; i++) {
accum = add(accum, inputs[i]);
}
auto reduced = sum(accum, {0});
fusion.addOutput(reduced);

auto unscheduled_fusion_copy = fusion;

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn(shape, options);
auto options =
at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
std::vector<c10::IValue> aten_inputs;
for (int64_t i = 0; i < n_inputs; i++) {
aten_inputs.push_back(at::randn({outer_size, iter_size}, options));
}

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({t0});
auto outputs = executor_cache.runFusionWithInputs(aten_inputs);

if (expectOuterTmaUsed(outer_size, iter_size, dtype_bytes)) {
if (expectOuterTmaUsed(outer_size, iter_size, n_inputs, dtype_bytes)) {
EXPECT_TRUE(tma_reduction_check::isOuterTmaParams(executor_cache))
<< "Expected outer TMA scheduler for outer_size=" << outer_size
<< " iter_size=" << iter_size;
<< " iter_size=" << iter_size << " n_inputs=" << n_inputs;
}

testValidate(&unscheduled_fusion_copy, outputs, {t0}, __LINE__, __FILE__, "");
testValidate(
&unscheduled_fusion_copy, outputs, aten_inputs, __LINE__, __FILE__, "");
}

INSTANTIATE_TEST_SUITE_P(
,
TmaOuterReductionTest,
testing::Combine(
testing::ValuesIn([] { // outer_size
std::vector<int64_t> vals;
for (int64_t v = 256; v <= 65536; v *= 4) {
vals.push_back(v);
}
return vals;
}()),
testing::ValuesIn([] { // iter_size
std::vector<int64_t> vals;
for (int64_t v = 256; v <= 65536; v *= 4) {
vals.push_back(v);
}
return vals;
}())),
testing::Values(
// Size sweep with single input, float
TmaOuterReductionTestParams{256, 256, 1, DataType::Float},
TmaOuterReductionTestParams{256, 4096, 1, DataType::Float},
TmaOuterReductionTestParams{256, 65536, 1, DataType::Float},
TmaOuterReductionTestParams{4096, 256, 1, DataType::Float},
TmaOuterReductionTestParams{4096, 4096, 1, DataType::Float},
TmaOuterReductionTestParams{4096, 65536, 1, DataType::Float},
TmaOuterReductionTestParams{65536, 256, 1, DataType::Float},
TmaOuterReductionTestParams{65536, 4096, 1, DataType::Float},
TmaOuterReductionTestParams{65536, 65536, 1, DataType::Float},
// Multi-input (exercises smem budget / tile shrinking)
TmaOuterReductionTestParams{4096, 1024, 2, DataType::Float},
TmaOuterReductionTestParams{4096, 1024, 3, DataType::Float},
TmaOuterReductionTestParams{4096, 1024, 5, DataType::Float},
TmaOuterReductionTestParams{4096, 1024, 2, DataType::Half},
TmaOuterReductionTestParams{4096, 1024, 3, DataType::Half},
TmaOuterReductionTestParams{4096, 1024, 5, DataType::Half}),
([](const testing::TestParamInfo<TmaOuterReductionTestParams>& info) {
auto [outer_size, iter_size] = info.param;
return "outer_" + std::to_string(outer_size) + "_iter_" +
auto [outer_size, iter_size, n_inputs, dtype] = info.param;
std::string name = "outer_" + std::to_string(outer_size) + "_iter_" +
std::to_string(iter_size);
if (n_inputs > 1) {
name += "_" + std::to_string(n_inputs) + "inputs";
}
if (dtype != DataType::Float) {
name += "_half";
}
return name;
}));

} // namespace nvfuser
Loading