@@ -127,7 +127,7 @@ class LayoutRematerialization {
127127 }
128128
129129 void cleanup ();
130- void backwardRematerialization ();
130+ bool backwardRematerialization ();
131131 void backwardRematerialization (ConvertLayoutOp convertOp);
132132 // TODO: Merge the three hoistConvert*(); functions as they are duplicate code
133133 void hoistConvertDotOperand ();
@@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
10191019 return success ();
10201020}
10211021
1022- void LayoutRematerialization::backwardRematerialization () {
1022+ bool LayoutRematerialization::backwardRematerialization () {
1023+ bool changed = false ;
10231024 // Go through each ConvertLayoutOp.
10241025 SmallVector<ConvertLayoutOp> convertOps;
10251026 funcOp.walk (
@@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() {
10311032 // backward slices.
10321033 addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
10331034 convertOp.getResult ());
1035+ } else {
1036+ changed = true ;
10341037 }
10351038 }
1039+ return changed;
10361040}
10371041
10381042void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast () {
@@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15931597 rewriteSlice (slice, layout, convertOp, mapping);
15941598}
15951599
1596- void backwardRematerialization (ModuleOp module ) {
1597- module .walk ([](FuncOp funcOp) {
1600+ bool backwardRematerialization (ModuleOp module ) {
1601+ bool changed = false ;
1602+ module .walk ([&](FuncOp funcOp) {
15981603 LayoutRematerialization layoutRemat (funcOp);
1599- layoutRemat.backwardRematerialization ();
1604+ changed |= layoutRemat.backwardRematerialization ();
16001605 layoutRemat.cleanup ();
16011606 });
1607+ return changed;
16021608}
16031609
16041610void hoistConvert (ModuleOp module ) {
@@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass
16591665
16601666 cleanupConvertOps ();
16611667
1662- // 2. For remaining convert ops, try to rematerialize the slice of producer
1663- // operation to avoid having to convert.
1664- backwardRematerialization (m);
1665- LLVM_DEBUG ({
1666- DBGS () << " Module after backward remat:\n " ;
1667- m.dump ();
1668- });
1669-
1670- // Cleanup dummy converts created during backward remat.
1671- cleanupConvertOps ();
1672-
1668+ bool changed = false ;
1669+ do {
1670+ changed = false ;
1671+ // 2. For remaining convert ops, try to rematerialize the slice of
1672+ // producer operation to avoid having to convert.
1673+ changed = backwardRematerialization (m);
1674+ LLVM_DEBUG ({
1675+ DBGS () << " Module after backward remat:\n " ;
1676+ m.dump ();
1677+ });
1678+
1679+ // Cleanup dummy converts created during backward remat.
1680+ cleanupConvertOps ();
1681+ } while (changed);
16731682 // 3. For remaining converts, try to hoist them above cast generating larger
16741683 // size types in order to reduce the cost of the convert op.
16751684 hoistConvert (m);
0 commit comments