
Commit 1eacd9b

Drop codegen support of gather (but not takeAlongAxis) (#5907)
Gather allows the non-gathered dimensions of the output to be smaller than those of the lookup tensor, which complicates indexing. TensorIndexer does not yet support this case; only the legacy indexer does. Note that takeAlongAxis, which is a restricted case of gather, is supported.

The motivation is to remove the legacy indexer, and this is the only remaining fallback case. One way to support it would be to decompose gather into a takeAlongAxis and a slice. For now, this PR disables codegen of gather and delegates it to ExprEval.

Note that the cross-entropy benchmark does use gather rather than takeAlongAxis, so there's a pending change needed in Thunder; see #3924 (comment). While this is a perf regression, at this point I think it's more important to remove the large technical debt.

In a follow-up PR, I'll remove the legacy indexer. This PR just inserts an assertion that no fallback is necessary, which should hold given the scheduler changes.
1 parent 1bded43 commit 1eacd9b
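
For context, the shape rule behind the change can be seen with plain ATen calls. This is a sketch for illustration only, not code from this PR; nvfuser's gather and takeAlongAxis ops follow the same semantics as at::gather and its exact-size restriction, respectively.

// Lookup tensor of shape {128, 128}, gathered along dim 1.
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({128, 128}, options);

// Non-exact case: the non-gathered extent of the index (4) is smaller than
// that of the lookup tensor (128), and the output follows the index shape.
// This is the form that codegen no longer supports.
at::Tensor idx = at::randint(0, 128, {4, 128}, options_i);
at::Tensor out = at::gather(t0, 1, idx); // shape {4, 128}

// Exact case, which takeAlongAxis models: every non-gathered extent of the
// index matches the lookup tensor, e.g. an index of shape {128, 1}. Only
// this form remains on the codegen path.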

File tree: 6 files changed, +18 −136 lines


csrc/id_model/indexing.cpp

Lines changed: 2 additions & 0 deletions

@@ -34,6 +34,8 @@
 namespace nvfuser {
 
 TensorIndexer::TensorIndexer(IdModel& id_model) : id_model_(id_model) {
+  NVF_ERROR(isSupported(id_model.fusion()));
+
   buildLoopIndexMap();
 
   if (isDebugDumpEnabled(DebugDumpOption::IndexingVerbose)) {

csrc/scheduler/expr_eval_sched.cpp

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
   // TODO: remove IndexPutAccumulateOp
   if (exprs.front()
           ->isOneOf<
+              GatherOp,
               ScatterOp,
               SdpaFwdOp,
               SdpaBwdOp,
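With GatherOp added to this list, a segment whose only expression is a gather can be claimed by the ExprEval scheduler. Below is a hedged sketch of how that might be exercised from a test, reusing helpers that appear elsewhere in this diff (makeContigTensor, FusionExecutorCache, validateSegmentation); the shapes are illustrative, not taken from the PR.

// A fusion containing a single non-exact gather.
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
auto tv1 = makeContigTensor(2, DataType::Int);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = gather(tv0, 1, tv1);
fusion.addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({32, 32}, options);
at::Tensor t1 = at::randint(0, 32, {8, 32}, options_i); // non-exact index

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({t0, t1});

// Expect the whole fusion to be handled by ExprEval rather than codegen.
validateSegmentation(
    executor_cache.getMostRecentKernelRuntime(), {SchedulerType::ExprEval});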

csrc/scheduler/registry.cpp

Lines changed: 10 additions & 0 deletions

@@ -64,6 +64,16 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
     return false;
   }
 
+  // Support of non-exact gather was dropped when the legacy indexer was
+  // deprecated
+  if (std::ranges::any_of(
+          ir_utils::getOpsOfType<GatherOp>(fusion),
+          [](GatherOp* gather) { return !gather->exactSizes(); })) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        scheduler_type, "Non-exact gather ops");
+    return false;
+  }
+
   // Fusions with `MatmulOp, LinearOp, MmaOp` can only be accepted by Matmul
   // scheduler.
   if (scheduler_type != SchedulerType::Matmul &&
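The rejection hinges on GatherOp::exactSizes(). As a minimal sketch of what the check above distinguishes (this mirrors the new code and adds nothing beyond it):

for (GatherOp* op : ir_utils::getOpsOfType<GatherOp>(fusion)) {
  if (op->exactSizes()) {
    // takeAlongAxis-style gather: non-gathered index extents match the
    // lookup tensor, so TensorIndexer can index it and codegen schedulers
    // may still accept the fusion.
  } else {
    // Non-exact gather: every codegen scheduler now rejects the fusion
    // here, and segmentation routes the expression to ExprEval.
  }
}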

tests/cpp/test_gather.cpp

Lines changed: 1 addition & 134 deletions

@@ -582,7 +582,7 @@ TEST_F(GatherTest, TakeAlongAxisIntermediateTensorReduction1) {
 
   validateSegmentation(
       executor_cache.getMostRecentKernelRuntime(),
-      {SchedulerType::Reduction, SchedulerType::PointWise});
+      {SchedulerType::Reduction, SchedulerType::ExprEval});
 
   testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
 }
@@ -1126,137 +1126,4 @@ TEST_F(GatherTest, TakeAlongAxisCrossEntropyLoss) {
   testValidate(fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__);
 }
 
-// Test grouped reduction on IterType::GatherScatter
-TEST_F(GatherTest, GatherIterGoupedReduction) {
-  const int max_dim_size = 128;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-
-  int rank = 3;
-  int dim = 2;
-
-  auto fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
-
-  TensorView* tv1 = makeContigTensor(rank);
-  TensorView* tv_idx = makeContigTensor(rank, DataType::Int);
-  fusion.addInput(tv1);
-  fusion.addInput(tv_idx);
-  auto tv_gather = gather(tv1, dim, tv_idx);
-  auto tv_sum = sum(tv_gather, {0}, false);
-  fusion.addOutput(tv_sum);
-
-  // simply gather all elements
-  auto input_dims =
-      std::vector<int64_t>({max_dim_size, max_dim_size, max_dim_size});
-  auto index_dims = input_dims;
-  std::vector<int64_t> input2_dims(rank - 1, 0);
-  for (int idim = 0; idim < rank - 1; ++idim) {
-    input2_dims[idim] = index_dims[idim + 1];
-  }
-
-  at::Tensor t0 = at::randn(input_dims, options);
-  at::Tensor idx = at::randint(0, input_dims[dim], index_dims, options_i);
-
-  auto reduction_scheduler =
-      SchedulerEntry::makeSchedulerInstance(SchedulerType::Reduction);
-  SchedulerRuntimeInfo runtime_info(&fusion, {t0, idx});
-  auto heuristic_params =
-      reduction_scheduler->computeHeuristics(&fusion, runtime_info);
-  auto rparams = heuristic_params->as<ReductionParams>();
-
-  // Enforce vectorization so we can group them
-  const int vect_factor = 2;
-  rparams->vectorize_iter_dom = true;
-  rparams->unroll_factor_iter_dom = vect_factor;
-  // Enforce grid reduction, which requires a determined BIDy
-  // If the heuristic does not have a BIDy, bind it to 2
-  rparams->cross_grid_inner_reduction = true;
-  rparams->split_grid_dim_inner_reduction = true;
-  rparams->grid_dim_inner_reduction = ParallelType::BIDy;
-  if (!rparams->lparams.hasDim(ParallelType::BIDy)) {
-    rparams->lparams.bind(2L, ParallelType::BIDy);
-  }
-
-  reduction_scheduler->schedule(&fusion, rparams);
-
-  // lowering & check iteration grouped reductions
-  GpuLower gpulw(&fusion);
-  gpulw.run();
-  NVF_CHECK(
-      gpulw.kernel()->summary().has_iter_grouped_reductions,
-      "There must be iter domain grouped reductions.");
-  NVF_CHECK(
-      gpulw.kernel()->summary().num_grouped_iterations == vect_factor,
-      "Expected ",
-      vect_factor,
-      " grouped iterations, found ",
-      gpulw.kernel()->summary().num_grouped_iterations);
-
-  KernelExecutor ke;
-  auto lparams = rparams->lparams;
-  ke.compile(&fusion, {t0, idx}, lparams);
-  auto cg_outputs = ke.run({t0, idx}, {}, lparams);
-
-  auto t_gather = at::gather(t0, dim, idx);
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {t0, idx},
-      {t_gather.sum(0)},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST_F(GatherTest, SameTvUsedAsLookupAndIndex) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
-
-  // Create three input tensors
-  auto tv0 = makeContigTensor(2);
-  auto tv1 = makeContigTensor(2, DataType::Int);
-  auto tv2 = makeContigTensor(2, DataType::Int);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addInput(tv2);
-
-  auto tv3 = gather(tv0, 1, tv1);
-  auto tv4 = gather(tv1, 1, tv2);
-  auto tv5 = castOp(DataType::Float, tv4);
-  auto tv6 = add(tv3, tv5);
-  fusion.addOutput(tv6);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-
-  // Create test tensors
-  std::vector<int64_t> dims{4, 6};
-  at::Tensor t0 = at::randn(dims, options);
-  at::Tensor t1 = at::randint(0, dims[1], dims, options_i);
-  at::Tensor t2 = at::randint(0, dims[1], dims, options_i);
-
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
-  auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
-
-  auto runtime = executor_cache.getMostRecentKernelRuntime();
-  auto scheduled_fusion = runtime->executors()
-                              .back()
-                              ->as<KernelExecutor>()
-                              ->compiledKernel()
-                              ->kernel();
-  auto tv1_uses = scheduled_fusion->inputs().at(1)->uses();
-  EXPECT_EQ(tv1_uses.size(), 2);
-  EXPECT_THAT(
-      tv1_uses,
-      testing::UnorderedElementsAre(
-          testing::Truly([](Expr* e) { return e->isA<GatherOp>(); }),
-          testing::Truly([](Expr* e) { return e->isA<LoadStoreOp>(); })));
-
-  // Validate the result
-  testValidate(&fusion, cg_outputs, {t0, t1, t2}, __LINE__, __FILE__);
-}
 } // namespace nvfuser

tests/cpp/test_persistent_buffer.cpp

Lines changed: 3 additions & 1 deletion

@@ -1941,7 +1941,9 @@ TEST_F(PersistentBufferTest, BufferGatherLookupTv) {
   auto tv2 = sum(tv1, {1});
   auto tv3 = broadcast(tv2, {false, true});
   auto tv4 = broadcast(index_tv, {false, true});
-  auto tv5 = gather(tv0, 1, tv4);
+  // Use takeAlongAxis rather than gather as codegen does not support
+  // the latter
+  auto tv5 = takeAlongAxis(tv0, tv4, 1);
   auto tv6 = maybeCastOp(DataType::BFloat16, tv5);
   auto tv7 = add(tv3, tv6);
   auto tv8 = add(tv1, tv7);

tests/cpp/test_reduction.cpp

Lines changed: 1 addition & 1 deletion

@@ -2563,7 +2563,7 @@ TEST_F(ReductionTest, CrossEntropyGatherPattern) {
   fusion.addInput(labels);
 
   auto tv2 = broadcast(labels, {false, true});
-  auto tv3 = gather(log_probs, 1, tv2);
+  auto tv3 = takeAlongAxis(log_probs, tv2, 1);
   auto tv4 = squeeze(tv3, std::vector<bool>({false, true}));
 
   fusion.addOutput(tv4);
