Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "debug.h"
#include "instrumentation.h"
#include "ir/utils.h"
#include "scheduler/debug_utils.h"
#include "scheduler/reduction_non_tma.h"
#include "scheduler/reduction_outer_tma.h"
Expand Down Expand Up @@ -252,6 +253,11 @@ bool mayUseTma(
return false;
}

// Welford reductions not yet supported by TMA schedulers.
if (ir_utils::hasOpsOfType<WelfordOp>(fusion)) {
return false;
}

return true;
}

Expand Down Expand Up @@ -282,8 +288,15 @@ bool mayUseTmaOuter(
return false;
}

// TMA tile may exceed Smem size for multiple tensors.
if (props.n_tensor_inputs != 1) {
// Check that TMA tiles for all inputs fit in shared memory.
// The heuristic will shrink tiles to fit, but the minimum tile size is
// bdimy × bdimx (16 × 32). Reject if even minimum tiles don't fit.
const int64_t min_tile_r = 16; // bdimy
const int64_t min_tile_i = 32; // bdimx
int64_t min_smem_per_input = min_tile_r * min_tile_i *
(props.max_dtype_size_bit_for_vectorization / 8);
int64_t min_total_smem = min_smem_per_input * props.n_tensor_inputs;
if (min_total_smem > (int64_t)dev_prop->sharedMemPerBlockOptin) {
return false;
}

Expand All @@ -300,6 +313,11 @@ bool mayUseTmaOuter(
return false;
}

// Welford reductions not yet supported by the TMA outer scheduler.
if (ir_utils::hasOpsOfType<WelfordOp>(fusion)) {
return false;
}

return true;
}
} // namespace
Expand Down
36 changes: 29 additions & 7 deletions csrc/scheduler/reduction_outer_tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <ATen/cuda/CUDAContext.h>

#include "ir/utils.h"
#include "scheduler/cache_policy_refiner.h"
#include "scheduler/reduction_utils.h"
#include "scheduler/runtime_info.h"
Expand All @@ -34,9 +35,26 @@ std::unique_ptr<TmaOuterReductionParams> getReductionHeuristics(
const int64_t bdimx = 32;
const int64_t bdimy = 16;

// TMA tile sizes. Unroll factors are derived from these and thread dims.
const int64_t tma_tile_i = 128;
const int64_t tma_tile_r = 128;
// Compute TMA tile sizes based on available shared memory budget.
// Each input tensor gets a tma_tile_r * tma_tile_i tile in smem.
auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t smem_bytes = (int64_t)dev_prop->sharedMemPerBlockOptin;
const int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
const int64_t n_inputs = std::max(props.n_tensor_inputs, (int64_t)1);
const int64_t smem_per_input = smem_bytes / n_inputs;

// Start with default tile sizes and shrink if they exceed the per-input
// smem budget. Tile sizes must remain multiples of thread block dims.
int64_t tma_tile_i = 128;
int64_t tma_tile_r = 128;
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
tma_tile_r > bdimy) {
tma_tile_r /= 2;
}
while (tma_tile_r * tma_tile_i * dtype_bytes > smem_per_input &&
tma_tile_i > bdimx) {
tma_tile_i /= 2;
}

NVF_ERROR(tma_tile_i % bdimx == 0);
NVF_ERROR(tma_tile_r % bdimy == 0);
Expand Down Expand Up @@ -110,6 +128,10 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
// Reorder from [I, R] -> [R, I] for outer reduction pattern
reduction_tv->reorder({{0, 1}, {1, 0}});

// Propagate the canonicalized [R, I] merges to all tensors.
TransformPropagator canon_propagator(reduction_tv);
MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&canon_propagator);

// The reduction_tv already has a cacheBefore from cacheAndForkOutputs.
// Use reduction_tv directly as our reduction reference for scheduling.
TensorView* redu_tv = reduction_tv;
Expand All @@ -118,9 +140,8 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
const int64_t outer_reduce_axis = 0;

// Phase 2: Schedule TMA tensor with 2D TMA tiling
// Apply transforms to the TMA smem TV.
// We start from the TMA TV, which shares the same logical domain as the
// reduction TV's producer.
// Apply transforms to the TMA smem TV (now in [R, I] form after
// canonicalization propagation).
TensorView* tma_tv = tma_tvs[0];

// [R, I] -> [R/tma_tile_r, tma_tile_r, I]
Expand Down Expand Up @@ -166,6 +187,7 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {
redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
redu_tv->axis(4)->parallelize(ParallelType::BIDx);
redu_tv->axis(5)->parallelize(ParallelType::TIDx); // bdimx

// Use Vectorize so it gets converted to Group for iterGroupedGridReduce
redu_tv->axis(6)->parallelize(ParallelType::Vectorize); // iter_unroll

Expand All @@ -180,7 +202,7 @@ void scheduleReduction(Fusion* fusion, const TmaOuterReductionParams* rparams) {

TensorView* reference_tv = redu_tv;
if (!rfactor_axes.empty()) {
reference_tv = redu_tv->rFactor(rfactor_axes);
reference_tv = ir_utils::rFactorHelper(redu_tv, rfactor_axes);
}

// Phase 8: Propagate thread-level splits to non-TMA TVs
Expand Down
94 changes: 62 additions & 32 deletions tests/cpp/test_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3099,7 +3099,8 @@ INSTANTIATE_TEST_SUITE_P(

// Test outer reduction with auto-scheduled TMA
using TmaOuterReductionTestParams =
std::tuple<int64_t, int64_t>; // <outer_size, iter_size>
std::tuple<int64_t, int64_t, int64_t, DataType>;
// <outer_size, iter_size, n_inputs, dtype>

class TmaOuterReductionTest
: public NVFuserFixtureParamTest<TmaOuterReductionTestParams> {
Expand All @@ -3115,6 +3116,7 @@ class TmaOuterReductionTest
bool expectOuterTmaUsed(
int64_t outer_size,
int64_t iter_size,
int64_t n_inputs,
int64_t dtype_bytes) {
uint64_t total_reduction_bytes = outer_size * dtype_bytes;
uint64_t min_tma_bytes = 16384;
Expand All @@ -3125,6 +3127,13 @@ class TmaOuterReductionTest
if ((iter_size * dtype_bytes) % 16 != 0) {
return false;
}
// Check minimum tile smem fits for all inputs
int64_t min_smem_per_input = 16 * 32 * dtype_bytes;
auto* dev_prop = at::cuda::getCurrentDeviceProperties();
if (min_smem_per_input * n_inputs >
(int64_t)dev_prop->sharedMemPerBlockOptin) {
return false;
}
return true;
}

Expand All @@ -3133,8 +3142,7 @@ class TmaOuterReductionTest
};

TEST_P(TmaOuterReductionTest, Sum) {
auto [outer_size, iter_size] = GetParam();
auto dtype = DataType::Float;
auto [outer_size, iter_size, n_inputs, dtype] = GetParam();
int64_t dtype_bytes = dataTypeSizeByte(dtype);

// TMA requires the stride (iter_size * dtype_bytes) to be 16-byte aligned
Expand All @@ -3143,56 +3151,78 @@ TEST_P(TmaOuterReductionTest, Sum) {
return;
}

std::vector<int64_t> shape = {outer_size, iter_size};

auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;

auto tv0 = makeContigTensor(2);
fusion.addInput(tv0);
auto tv1 = sum(tv0, {0}); // reduce along axis 0 (outer reduction)
fusion.addOutput(tv1);
// Create n_inputs tensors and sum them element-wise before reducing
std::vector<TensorView*> inputs;
for (int64_t i = 0; i < n_inputs; i++) {
auto tv = makeContigTensor(2, dtype);
fusion.addInput(tv);
inputs.push_back(tv);
}
auto accum = inputs[0];
for (int64_t i = 1; i < n_inputs; i++) {
accum = add(accum, inputs[i]);
}
auto reduced = sum(accum, {0});
fusion.addOutput(reduced);

auto unscheduled_fusion_copy = fusion;

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn(shape, options);
auto options =
at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
std::vector<c10::IValue> aten_inputs;
for (int64_t i = 0; i < n_inputs; i++) {
aten_inputs.push_back(at::randn({outer_size, iter_size}, options));
}

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({t0});
auto outputs = executor_cache.runFusionWithInputs(aten_inputs);

if (expectOuterTmaUsed(outer_size, iter_size, dtype_bytes)) {
if (expectOuterTmaUsed(outer_size, iter_size, n_inputs, dtype_bytes)) {
EXPECT_TRUE(tma_reduction_check::isOuterTmaParams(executor_cache))
<< "Expected outer TMA scheduler for outer_size=" << outer_size
<< " iter_size=" << iter_size;
<< " iter_size=" << iter_size << " n_inputs=" << n_inputs;
}

testValidate(&unscheduled_fusion_copy, outputs, {t0}, __LINE__, __FILE__, "");
testValidate(
&unscheduled_fusion_copy, outputs, aten_inputs, __LINE__, __FILE__, "");
}

INSTANTIATE_TEST_SUITE_P(
,
TmaOuterReductionTest,
testing::Combine(
testing::ValuesIn([] { // outer_size
std::vector<int64_t> vals;
for (int64_t v = 256; v <= 65536; v *= 4) {
vals.push_back(v);
}
return vals;
}()),
testing::ValuesIn([] { // iter_size
std::vector<int64_t> vals;
for (int64_t v = 256; v <= 65536; v *= 4) {
vals.push_back(v);
}
return vals;
}())),
testing::Values(
// Size sweep with single input, float
TmaOuterReductionTestParams{256, 256, 1, DataType::Float},
TmaOuterReductionTestParams{256, 4096, 1, DataType::Float},
TmaOuterReductionTestParams{256, 65536, 1, DataType::Float},
TmaOuterReductionTestParams{4096, 256, 1, DataType::Float},
TmaOuterReductionTestParams{4096, 4096, 1, DataType::Float},
TmaOuterReductionTestParams{4096, 65536, 1, DataType::Float},
TmaOuterReductionTestParams{65536, 256, 1, DataType::Float},
TmaOuterReductionTestParams{65536, 4096, 1, DataType::Float},
TmaOuterReductionTestParams{65536, 65536, 1, DataType::Float},
// Multi-input (exercises smem budget / tile shrinking)
TmaOuterReductionTestParams{4096, 1024, 2, DataType::Float},
TmaOuterReductionTestParams{4096, 1024, 3, DataType::Float},
TmaOuterReductionTestParams{4096, 1024, 5, DataType::Float},
TmaOuterReductionTestParams{4096, 1024, 2, DataType::Half},
TmaOuterReductionTestParams{4096, 1024, 3, DataType::Half},
TmaOuterReductionTestParams{4096, 1024, 5, DataType::Half}),
([](const testing::TestParamInfo<TmaOuterReductionTestParams>& info) {
auto [outer_size, iter_size] = info.param;
return "outer_" + std::to_string(outer_size) + "_iter_" +
auto [outer_size, iter_size, n_inputs, dtype] = info.param;
std::string name = "outer_" + std::to_string(outer_size) + "_iter_" +
std::to_string(iter_size);
if (n_inputs > 1) {
name += "_" + std::to_string(n_inputs) + "inputs";
}
if (dtype != DataType::Float) {
name += "_half";
}
return name;
}));

} // namespace nvfuser
Loading