
Commit fc60235

add extra domain restriction to transpose scheduler (#5884)
Solves issue #5883 for better performance: use the pointwise scheduler when the two transpose groups are caused by broadcast domains rather than by an actual permutation.
1 parent 69da2d1 commit fc60235
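For context, the fusion this targets is a plain broadcast-multiply with no permutation anywhere. Below is a condensed sketch of the new NoTransposeMaverick17B test added in this commit (nvFuser C++ test API; it assumes a test body where a Fusion named fusion is already bound to a FusionGuard, as in tests/cpp/test_transpose.cpp):

// tv0 is [i0, i1] and tv1 is [i2, b3]; the inner domain of tv1 is a
// broadcast, which is what used to split the inputs into two transpose
// groups even though nothing is transposed.
auto tv0 = makeContigConcreteTensor({262144, 5120}, DataType::BFloat16);
auto tv1 = makeContigConcreteTensor({262144, 1}, DataType::BFloat16);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv4 = mul(castOp(DataType::Float, tv0), castOp(DataType::Float, tv1));
fusion.addOutput(castOp(DataType::BFloat16, tv4));
// After this commit the heuristic chosen for this fusion is
// SchedulerType::PointWise rather than SchedulerType::Transpose.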

2 files changed: +121 -14


csrc/scheduler/tools/domain_map.cpp

Lines changed: 40 additions & 1 deletion
@@ -8,6 +8,8 @@
 #include <scheduler/tools/domain_map.h>
 #include <scheduler/utils.h>
 
+#include <ranges>
+
 namespace nvfuser {
 namespace scheduler_tools {
 
@@ -515,10 +517,47 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) {
   if (ref1 == nullptr || ref2 == nullptr) {
     return false;
   }
+
   // reference 1 is the global reference, so it must have dim mapped the
   // innermost dim of both groups
   auto innermost2 = scheduler_utils::innerMostAllocDim(ref2);
-  return domain_map.getMappedAllocDimIn(ref1, innermost2) != nullptr;
+  auto mapped_id = domain_map.getMappedAllocDimIn(ref1, innermost2);
+  if (mapped_id == nullptr) {
+    return false;
+  }
+
+  // For grouping caused by permutation, the corresponding allocation domains
+  // should not all be mapped to each other. If they are, the two groups are
+  // due to broadcast. In that case they are not considered valid groups,
+  // since the broadcast tensor has a smaller size and the pointwise scheduler
+  // handles broadcast well through unrolling and caching at all levels. For
+  // example, in TransposeTest.NoTransposeMaverick17B, the two inputs are
+  // tv0[i0, i1] and tv1[i2, b3], where i0/i2 and i1/b3 are mapped to each
+  // other. However, tv0 and tv1 are in two different groups because of the
+  // broadcast. In this case, we should use the pointwise scheduler instead of
+  // the transpose scheduler.
+  const auto& ref1_loop = ref1->getMaybeAllocationDomain();
+  const auto& ref2_loop = ref2->getMaybeAllocationDomain();
+  const auto& ca_map = domain_map.getComputeAtMap();
+  const bool all_mapped = std::ranges::equal(
+      ref1_loop, ref2_loop, [&](IterDomain* id1, IterDomain* id2) {
+        return ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE);
+      });
+  if (all_mapped) {
+    // Not required, just to validate the assumption that all_mapped implies
+    // any_bcast
+    const bool any_bcast =
+        std::ranges::any_of(
+            ref1_loop, [](IterDomain* id) { return id->isBroadcast(); }) ||
+        std::ranges::any_of(
+            ref2_loop, [](IterDomain* id) { return id->isBroadcast(); });
+    NVF_ERROR(
+        any_bcast,
+        "all_mapped implies any_bcast, ca_map:\n",
+        ca_map.toString());
+    return false;
+  }
+  return true;
 }
 
 int64_t TransposeDomainMap::getInnerLeafDim(
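To make the new check concrete outside of nvFuser, here is a minimal standalone C++20 sketch of the same decision. Dim, group, and broadcast are hypothetical stand-ins for IterDomain and the permissive ComputeAtMap mapping; they are not nvFuser types.

// Standalone sketch (hypothetical stand-ins, not nvFuser code): if every
// allocation domain of ref1 pairwise-maps to ref2's, the two transpose
// groups can only have come from a broadcast, so bail out to pointwise.
#include <algorithm>
#include <iostream>
#include <vector>

struct Dim {
  int group;      // domains with the same group id "map" to each other
  bool broadcast; // stand-in for IterDomain::isBroadcast()
};

int main() {
  // tv0[i0, i1] vs tv1[i2, b3]: i0/i2 and i1/b3 map, and b3 is a broadcast.
  std::vector<Dim> ref1 = {{0, false}, {1, false}};
  std::vector<Dim> ref2 = {{0, false}, {1, true}};

  const bool all_mapped = std::ranges::equal(
      ref1, ref2, [](const Dim& a, const Dim& b) { return a.group == b.group; });
  const bool any_bcast =
      std::ranges::any_of(ref1, [](const Dim& d) { return d.broadcast; }) ||
      std::ranges::any_of(ref2, [](const Dim& d) { return d.broadcast; });

  // Prints "true true": the grouping is broadcast-induced, so the transpose
  // scheduler should not claim this fusion.
  std::cout << std::boolalpha << all_mapped << ' ' << any_bcast << '\n';
}

If all allocation domains are pairwise mapped yet the tensors still fell into two groups, the split can only have come from a broadcast, which is exactly the case the commit now hands to the pointwise scheduler.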

tests/cpp/test_transpose.cpp

Lines changed: 81 additions & 13 deletions
@@ -20,6 +20,7 @@
 #include "scheduler/transpose.h"
 #include "scheduler/utils.h"
 #include "tests/cpp/utils.h"
+#include "type.h"
 #include "validator_utils.h"
 
 namespace nvfuser {
@@ -284,11 +285,14 @@ TEST_F(TransposeTest, FusionScheduleTransposeNoReference) {
 
 // x->broadcast--add->z
 // y->broadcast-/
+// pointwise: 61%
+// transpose: 39%
 TEST_F(TransposeTest, FusionScheduleBroadcastOnly) {
   for (bool contig0 : {true, false}) {
     for (bool contig1 : {true, false}) {
-      Fusion fusion;
-      FusionGuard fg(&fusion);
+      auto fusion_ptr = std::make_unique<Fusion>();
+      FusionGuard fg(fusion_ptr.get());
+      Fusion& fusion = *fusion_ptr;
       auto tv0 = contig0 ? makeContigConcreteTensor({-1, 1, -1})
                          : makeConcreteTensor({-1, 1, -1});
       auto tv1 = contig1 ? makeContigConcreteTensor({-1, -1, 1})
@@ -302,10 +306,24 @@ TEST_F(TransposeTest, FusionScheduleBroadcastOnly) {
       at::Tensor input0 = at::randn({1024, 1, 256}, options);
       at::Tensor input1 = at::randn({1024, 1024, 1}, options);
 
-      auto cg_outputs =
-          scheduleAndRun(&fusion, SchedulerType::Transpose, {input0, input1})
-              .outputs;
-      testValidate(&fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__);
+      FusionExecutorCache executor_cache(std::move(fusion_ptr));
+      auto outputs = executor_cache.runFusionWithInputs({input0, input1});
+      auto runtime = executor_cache.getMostRecentKernelRuntime();
+      auto heuristic = runtime->schedulerHeuristics()
+                           ->heuristicsList()
+                           .at(0)
+                           .get()
+                           ->scheduler_type;
+      NVF_CHECK(
+          heuristic == SchedulerType::PointWise,
+          "Unexpected heuristic: ",
+          heuristic);
+      testValidate(
+          executor_cache.fusion(),
+          outputs,
+          {input0, input1},
+          __LINE__,
+          __FILE__);
     }
   }
 }
@@ -627,8 +645,9 @@ TEST_F(TransposeTest, FusionTransposeViewSelfMapping) {
 // t2->broadcast->sub->mul->relu->t6
 // t1------------------'
 TEST_F(TransposeTest, FusionScheduleTransposeMissingDim) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  auto fusion_ptr = std::make_unique<Fusion>();
+  FusionGuard fg(fusion_ptr.get());
+  Fusion& fusion = *fusion_ptr;
 
   auto tv0 = makeContigTensor(3);
   auto tv1 = makeContigConcreteTensor({1, -1, 1});
@@ -647,12 +666,24 @@ TEST_F(TransposeTest, FusionScheduleTransposeMissingDim) {
   at::Tensor input1 = at::randn({1, 512, 1}, options);
   at::Tensor input2 = at::randn({512}, options);
 
-  auto cg_outputs =
-      scheduleAndRun(
-          &fusion, SchedulerType::Transpose, {input0, input1, input2})
-          .outputs;
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs({input0, input1, input2});
+  auto runtime = executor_cache.getMostRecentKernelRuntime();
+  auto heuristic = runtime->schedulerHeuristics()
+                       ->heuristicsList()
+                       .at(0)
+                       .get()
+                       ->scheduler_type;
+  NVF_CHECK(
+      heuristic == SchedulerType::PointWise,
+      "Unexpected heuristic: ",
+      heuristic);
   testValidate(
-      &fusion, cg_outputs, {input0, input1, input2}, __LINE__, __FILE__);
+      executor_cache.fusion(),
+      outputs,
+      {input0, input1, input2},
+      __LINE__,
+      __FILE__);
 }
 
 // x->sin->transpose->cos->y
@@ -1407,4 +1438,41 @@ TEST_F(TransposeTest, DanglingBroadcastIssue4957) {
   testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
 }
 
+TEST_F(TransposeTest, NoTransposeMaverick17B) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  FusionGuard fg(fusion_ptr.get());
+  Fusion& fusion = *fusion_ptr;
+
+  auto dtype = DataType::BFloat16;
+  auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype);
+  auto tv1 = makeContigConcreteTensor({262144, 1}, dtype);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  auto tv2 = castOp(DataType::Float, tv0);
+  auto tv3 = castOp(DataType::Float, tv1);
+  auto tv4 = mul(tv2, tv3);
+  auto tv5 = castOp(dtype, tv4);
+  fusion.addOutput(tv5);
+
+  auto options =
+      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({262144, 5120}, options);
+  at::Tensor input1 = at::randn({262144, 1}, options);
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs({input0, input1});
+  auto runtime = executor_cache.getMostRecentKernelRuntime();
+  auto heuristic = runtime->schedulerHeuristics()
+                       ->heuristicsList()
+                       .at(0)
+                       .get()
+                       ->scheduler_type;
+  NVF_CHECK(
+      heuristic == SchedulerType::PointWise,
+      "Unexpected heuristic: ",
+      heuristic);
+  testValidate(
+      executor_cache.fusion(), outputs, {input0, input1}, __LINE__, __FILE__);
+}
+
 } // namespace nvfuser
