Skip to content

Commit 03b40df

Browse files
authored
[BACKEND] Cleanup IR after backward remat in RemoveLayoutConversions (triton-lang#5659)
Backward remat creates a lot of dummy convert ops that should be removed before running the hosting phase otherwise we spend a lot of time looking at dummy convert. This doesn't change the result of the pass but it helps reduce compile time and prevent combinatory explosion.
1 parent febe2a1 commit 03b40df

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,22 @@ class TritonGPURemoveLayoutConversionsPass
11731173
: public impl::TritonGPURemoveLayoutConversionsBase<
11741174
TritonGPURemoveLayoutConversionsPass> {
11751175
public:
1176+
// Cleanup convert ops.
1177+
void cleanupConvertOps() {
1178+
MLIRContext *context = &getContext();
1179+
ModuleOp m = getOperation();
1180+
RewritePatternSet cleanUpPatterns(context);
1181+
ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context);
1182+
if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)).failed()) {
1183+
signalPassFailure();
1184+
}
1185+
1186+
LLVM_DEBUG({
1187+
DBGS() << "Module after canonicalizing:\n";
1188+
m.dump();
1189+
});
1190+
}
1191+
11761192
void runOnOperation() override {
11771193
MLIRContext *context = &getContext();
11781194
ModuleOp m = getOperation();
@@ -1191,16 +1207,7 @@ class TritonGPURemoveLayoutConversionsPass
11911207
m.dump();
11921208
});
11931209

1194-
RewritePatternSet cleanUpPatterns(context);
1195-
ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context);
1196-
if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)).failed()) {
1197-
signalPassFailure();
1198-
}
1199-
1200-
LLVM_DEBUG({
1201-
DBGS() << "Module after canonicalizing:\n";
1202-
m.dump();
1203-
});
1210+
cleanupConvertOps();
12041211

12051212
// 2. For remaining convert ops, try to rematerialize the slice of producer
12061213
// operation to avoid having to convert.
@@ -1210,6 +1217,9 @@ class TritonGPURemoveLayoutConversionsPass
12101217
m.dump();
12111218
});
12121219

1220+
// Cleanup dummy converts created during backward remat.
1221+
cleanupConvertOps();
1222+
12131223
// 3. For remaining converts, try to hoist them above cast generating larger
12141224
// size types in order to reduce the cost of the convert op.
12151225
hoistConvert(m);

0 commit comments

Comments
 (0)