Skip to content

Commit 8f613ab

Browse files
committed
[mlir] Enable disable folding in dialect conversion
Previously this only happened post checking if the op isn't legal, but was done unconditionally post (and before other legalization patterns). Add option to not attempt folding and one to do so as last resort.
1 parent 6d4a093 commit 8f613ab

File tree

7 files changed

+86
-10
lines changed

7 files changed

+86
-10
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,16 @@ class PDLConversionConfig final {
11581158
// ConversionConfig
11591159
//===----------------------------------------------------------------------===//
11601160

1161+
/// An enum to control folding behavior during dialect conversion.
1162+
enum class DialectConversionFoldingMode {
1163+
/// Never attempt to fold.
1164+
Never,
1165+
/// Only attempt to fold not legal operations before applying patterns.
1166+
BeforePatterns,
1167+
/// Only attempt to fold not legal operations after applying patterns.
1168+
AfterPatterns,
1169+
};
1170+
11611171
/// Dialect conversion configuration.
11621172
struct ConversionConfig {
11631173
/// An optional callback used to notify about match failure diagnostics during
@@ -1240,6 +1250,10 @@ struct ConversionConfig {
12401250
/// your patterns do not trigger any IR rollbacks. For details, see
12411251
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
12421252
bool allowPatternRollback = true;
1253+
1254+
/// The folding mode to use during conversion.
1255+
DialectConversionFoldingMode foldingMode =
1256+
DialectConversionFoldingMode::BeforePatterns;
12431257
};
12441258

12451259
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,15 +2197,16 @@ OperationLegalizer::legalize(Operation *op,
21972197
return success();
21982198
}
21992199

2200-
// If the operation isn't legal, try to fold it in-place.
2201-
// TODO: Should we always try to do this, even if the op is
2202-
// already legal?
2203-
if (succeeded(legalizeWithFold(op, rewriter))) {
2204-
LLVM_DEBUG({
2205-
logSuccess(logger, "operation was folded");
2206-
logger.startLine() << logLineComment;
2207-
});
2208-
return success();
2200+
// If the operation is not legal, try to fold it in-place if the folding mode
2201+
// is 'BeforePatterns'. 'Never' will skip this.
2202+
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2203+
if (succeeded(legalizeWithFold(op, rewriter))) {
2204+
LLVM_DEBUG({
2205+
logSuccess(logger, "operation was folded");
2206+
logger.startLine() << logLineComment;
2207+
});
2208+
return success();
2209+
}
22092210
}
22102211

22112212
// Otherwise, we need to apply a legalization pattern to this operation.
@@ -2217,6 +2218,18 @@ OperationLegalizer::legalize(Operation *op,
22172218
return success();
22182219
}
22192220

2221+
// If the operation can't be legalized via patterns, try to fold it in-place
2222+
// if the folding mode is 'AfterPatterns'.
2223+
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2224+
if (succeeded(legalizeWithFold(op, rewriter))) {
2225+
LLVM_DEBUG({
2226+
logSuccess(logger, "operation was folded");
2227+
logger.startLine() << logLineComment;
2228+
});
2229+
return success();
2230+
}
2231+
}
2232+
22202233
LLVM_DEBUG({
22212234
logFailure(logger, "no matched legalization pattern");
22222235
logger.startLine() << logLineComment;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s
2+
3+
// CHECK-LABEL: @fold_legalization
4+
func.func @fold_legalization() -> i32 {
5+
// CHECK-NOT: op_in_place_self_fold
6+
// CHECK: 97
7+
%1 = "test.op_in_place_self_fold"() : () -> (i32)
8+
"test.return"(%1) : (i32) -> ()
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s
2+
3+
// CHECK-LABEL: @fold_legalization
4+
func.func @fold_legalization() -> i32 {
5+
// CHECK: op_in_place_self_fold
6+
// CHECK-SAME: folded
7+
%1 = "test.op_in_place_self_fold"() : () -> (i32)
8+
"test.return"(%1) : (i32) -> ()
9+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s
2+
3+
// CHECK-LABEL: @remove_foldable_op(
4+
func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
5+
// Check that op was not folded.
6+
// CHECK: "test.op_with_region_fold"
7+
%0 = "test.op_with_region_fold"(%arg0) ({
8+
"foo.op_with_region_terminator"() : () -> ()
9+
}) : (i32) -> (i32)
10+
"test.return"(%0) : (i32) -> ()
11+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
14781478
let results = (outs I32);
14791479
let hasFolder = 1;
14801480
}
1481+
def : Pat<(TestOpInPlaceSelfFold:$op $_),
1482+
(TestOpConstant ConstantAttr<I32Attr, "97">)>;
14811483

14821484
// Test op that simply returns success.
14831485
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,7 @@ struct TestLegalizePatternDriver
14021402
ConversionTarget target(getContext());
14031403
target.addLegalOp<ModuleOp>();
14041404
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1405-
TerminatorOp, OneRegionOp>();
1405+
TerminatorOp, OneRegionOp, TestOpConstant>();
14061406
target.addLegalOp(OperationName("test.legal_op", &getContext()));
14071407
target
14081408
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -1457,6 +1457,7 @@ struct TestLegalizePatternDriver
14571457
DumpNotifications dumpNotifications;
14581458
config.listener = &dumpNotifications;
14591459
config.unlegalizedOps = &unlegalizedOps;
1460+
config.foldingMode = foldingMode;
14601461
if (failed(applyPartialConversion(getOperation(), target,
14611462
std::move(patterns), config))) {
14621463
getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1476,6 +1477,7 @@ struct TestLegalizePatternDriver
14761477

14771478
ConversionConfig config;
14781479
DumpNotifications dumpNotifications;
1480+
config.foldingMode = foldingMode;
14791481
config.listener = &dumpNotifications;
14801482
if (failed(applyFullConversion(getOperation(), target,
14811483
std::move(patterns), config))) {
@@ -1490,6 +1492,7 @@ struct TestLegalizePatternDriver
14901492
// Analyze the convertible operations.
14911493
DenseSet<Operation *> legalizedOps;
14921494
ConversionConfig config;
1495+
config.foldingMode = foldingMode;
14931496
config.legalizableOps = &legalizedOps;
14941497
if (failed(applyAnalysisConversion(getOperation(), target,
14951498
std::move(patterns), config)))
@@ -1510,6 +1513,21 @@ struct TestLegalizePatternDriver
15101513
clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
15111514
clEnumValN(ConversionMode::Partial, "partial",
15121515
"Perform a partial conversion"))};
1516+
1517+
Option<DialectConversionFoldingMode> foldingMode{
1518+
*this, "test-legalize-folding-mode",
1519+
llvm::cl::desc("The folding mode to use with the test driver"),
1520+
llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
1521+
llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
1522+
"Never attempt to fold"),
1523+
clEnumValN(DialectConversionFoldingMode::BeforePatterns,
1524+
"before-patterns",
1525+
"Only attempt to fold not legal operations "
1526+
"before applying patterns"),
1527+
clEnumValN(DialectConversionFoldingMode::AfterPatterns,
1528+
"after-patterns",
1529+
"Only attempt to fold not legal operations "
1530+
"after applying patterns"))};
15131531
};
15141532
} // namespace
15151533

0 commit comments

Comments
 (0)