Skip to content

Commit ce3d636

Browse files
ThomasRaoux authored and anmyachev committed
[BACKEND] run remove backward prop until a fix point (#8776)
1 parent e5d0ec4 commit ce3d636

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class LayoutRematerialization {
127127
}
128128

129129
void cleanup();
130-
void backwardRematerialization();
130+
bool backwardRematerialization();
131131
void backwardRematerialization(ConvertLayoutOp convertOp);
132132
// TODO: Merge the three hoistConvert*(); functions as they are duplicate code
133133
void hoistConvertDotOperand();
@@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
10191019
return success();
10201020
}
10211021

1022-
void LayoutRematerialization::backwardRematerialization() {
1022+
bool LayoutRematerialization::backwardRematerialization() {
1023+
bool changed = false;
10231024
// Go through each ConvertLayoutOp.
10241025
SmallVector<ConvertLayoutOp> convertOps;
10251026
funcOp.walk(
@@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() {
10311032
// backward slices.
10321033
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
10331034
convertOp.getResult());
1035+
} else {
1036+
changed = true;
10341037
}
10351038
}
1039+
return changed;
10361040
}
10371041

10381042
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
@@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15931597
rewriteSlice(slice, layout, convertOp, mapping);
15941598
}
15951599

1596-
void backwardRematerialization(ModuleOp module) {
1597-
module.walk([](FuncOp funcOp) {
1600+
bool backwardRematerialization(ModuleOp module) {
1601+
bool changed = false;
1602+
module.walk([&](FuncOp funcOp) {
15981603
LayoutRematerialization layoutRemat(funcOp);
1599-
layoutRemat.backwardRematerialization();
1604+
changed |= layoutRemat.backwardRematerialization();
16001605
layoutRemat.cleanup();
16011606
});
1607+
return changed;
16021608
}
16031609

16041610
void hoistConvert(ModuleOp module) {
@@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass
16591665

16601666
cleanupConvertOps();
16611667

1662-
// 2. For remaining convert ops, try to rematerialize the slice of producer
1663-
// operation to avoid having to convert.
1664-
backwardRematerialization(m);
1665-
LLVM_DEBUG({
1666-
DBGS() << "Module after backward remat:\n";
1667-
m.dump();
1668-
});
1669-
1670-
// Cleanup dummy converts created during backward remat.
1671-
cleanupConvertOps();
1672-
1668+
bool changed = false;
1669+
do {
1670+
changed = false;
1671+
// 2. For remaining convert ops, try to rematerialize the slice of
1672+
// producer operation to avoid having to convert.
1673+
changed = backwardRematerialization(m);
1674+
LLVM_DEBUG({
1675+
DBGS() << "Module after backward remat:\n";
1676+
m.dump();
1677+
});
1678+
1679+
// Cleanup dummy converts created during backward remat.
1680+
cleanupConvertOps();
1681+
} while (changed);
16731682
// 3. For remaining converts, try to hoist them above cast generating larger
16741683
// size types in order to reduce the cost of the convert op.
16751684
hoistConvert(m);

test/TritonGPU/combine.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,11 +2500,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
25002500
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
25012501
%3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
25022502
%4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
2503-
// CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>)
2504-
// FIXME: The optimal number of conversions should be 4.
2505-
// CHECK-COUNT-5: convert_layout
2503+
// CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
2504+
// CHECK-COUNT-4: convert_layout
25062505
// CHECK-NOT: convert_layout
2507-
// CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
2506+
// CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
25082507
// CHECK: }
25092508
// CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
25102509
%5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 {

0 commit comments

Comments
 (0)