Skip to content

Commit dc70c3b

Browse files
jpienaargithub-actions[bot]
authored andcommitted
Automerge: [mlir] Enable disabling folding in dialect conversion (#152890)
Previously this only happened post checking if the op is legal, but was done unconditionally post (and before other legalization patterns). Add option to not attempt folding and one to do so as last resort. Did consider but did not add a always attempt to fold option (which would have folded whether or not legal), but removed TODO about it.
2 parents a49ba25 + 6ec0985 commit dc70c3b

File tree

7 files changed

+89
-11
lines changed

7 files changed

+89
-11
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,16 @@ class PDLConversionConfig final {
11611161
// ConversionConfig
11621162
//===----------------------------------------------------------------------===//
11631163

1164+
/// An enum to control folding behavior during dialect conversion.
1165+
enum class DialectConversionFoldingMode {
1166+
/// Never attempt to fold.
1167+
Never,
1168+
/// Only attempt to fold not legal operations before applying patterns.
1169+
BeforePatterns,
1170+
/// Only attempt to fold not legal operations after applying patterns.
1171+
AfterPatterns,
1172+
};
1173+
11641174
/// Dialect conversion configuration.
11651175
struct ConversionConfig {
11661176
/// An optional callback used to notify about match failure diagnostics during
@@ -1243,6 +1253,10 @@ struct ConversionConfig {
12431253
/// your patterns do not trigger any IR rollbacks. For details, see
12441254
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
12451255
bool allowPatternRollback = true;
1256+
1257+
/// The folding mode to use during conversion.
1258+
DialectConversionFoldingMode foldingMode =
1259+
DialectConversionFoldingMode::BeforePatterns;
12461260
};
12471261

12481262
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,15 +2257,17 @@ OperationLegalizer::legalize(Operation *op,
22572257
return success();
22582258
}
22592259

2260-
// If the operation isn't legal, try to fold it in-place.
2261-
// TODO: Should we always try to do this, even if the op is
2262-
// already legal?
2263-
if (succeeded(legalizeWithFold(op, rewriter))) {
2264-
LLVM_DEBUG({
2265-
logSuccess(logger, "operation was folded");
2266-
logger.startLine() << logLineComment;
2267-
});
2268-
return success();
2260+
// If the operation is not legal, try to fold it in-place if the folding mode
2261+
// is 'BeforePatterns'. 'Never' will skip this.
2262+
const ConversionConfig &config = rewriter.getConfig();
2263+
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2264+
if (succeeded(legalizeWithFold(op, rewriter))) {
2265+
LLVM_DEBUG({
2266+
logSuccess(logger, "operation was folded");
2267+
logger.startLine() << logLineComment;
2268+
});
2269+
return success();
2270+
}
22692271
}
22702272

22712273
// Otherwise, we need to apply a legalization pattern to this operation.
@@ -2277,6 +2279,18 @@ OperationLegalizer::legalize(Operation *op,
22772279
return success();
22782280
}
22792281

2282+
// If the operation can't be legalized via patterns, try to fold it in-place
2283+
// if the folding mode is 'AfterPatterns'.
2284+
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2285+
if (succeeded(legalizeWithFold(op, rewriter))) {
2286+
LLVM_DEBUG({
2287+
logSuccess(logger, "operation was folded");
2288+
logger.startLine() << logLineComment;
2289+
});
2290+
return success();
2291+
}
2292+
}
2293+
22802294
LLVM_DEBUG({
22812295
logFailure(logger, "no matched legalization pattern");
22822296
logger.startLine() << logLineComment;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s
2+
3+
// CHECK-LABEL: @fold_legalization
4+
func.func @fold_legalization() -> i32 {
5+
// CHECK-NOT: op_in_place_self_fold
6+
// CHECK: 97
7+
%1 = "test.op_in_place_self_fold"() : () -> (i32)
8+
"test.return"(%1) : (i32) -> ()
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s
2+
3+
// CHECK-LABEL: @fold_legalization
4+
func.func @fold_legalization() -> i32 {
5+
// CHECK: op_in_place_self_fold
6+
// CHECK-SAME: folded
7+
%1 = "test.op_in_place_self_fold"() : () -> (i32)
8+
"test.return"(%1) : (i32) -> ()
9+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s
2+
3+
// CHECK-LABEL: @remove_foldable_op(
4+
func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
5+
// Check that op was not folded.
6+
// CHECK: "test.op_with_region_fold"
7+
%0 = "test.op_with_region_fold"(%arg0) ({
8+
"foo.op_with_region_terminator"() : () -> ()
9+
}) : (i32) -> (i32)
10+
"test.return"(%0) : (i32) -> ()
11+
}
12+

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
14981498
let results = (outs I32);
14991499
let hasFolder = 1;
15001500
}
1501+
def : Pat<(TestOpInPlaceSelfFold:$op $_),
1502+
(TestOpConstant ConstantAttr<I32Attr, "97">)>;
15011503

15021504
// Test op that simply returns success.
15031505
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,8 +1507,8 @@ struct TestLegalizePatternDriver
15071507
ConversionTarget target(getContext());
15081508
target.addLegalOp<ModuleOp>();
15091509
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1510-
TerminatorOp, OneRegionOp, TestValidProducerOp,
1511-
TestValidConsumerOp>();
1510+
TerminatorOp, TestOpConstant, OneRegionOp,
1511+
TestValidProducerOp, TestValidConsumerOp>();
15121512
target.addLegalOp(OperationName("test.legal_op", &getContext()));
15131513
target
15141514
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -1563,6 +1563,7 @@ struct TestLegalizePatternDriver
15631563
DumpNotifications dumpNotifications;
15641564
config.listener = &dumpNotifications;
15651565
config.unlegalizedOps = &unlegalizedOps;
1566+
config.foldingMode = foldingMode;
15661567
if (failed(applyPartialConversion(getOperation(), target,
15671568
std::move(patterns), config))) {
15681569
getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1582,6 +1583,7 @@ struct TestLegalizePatternDriver
15821583

15831584
ConversionConfig config;
15841585
DumpNotifications dumpNotifications;
1586+
config.foldingMode = foldingMode;
15851587
config.listener = &dumpNotifications;
15861588
if (failed(applyFullConversion(getOperation(), target,
15871589
std::move(patterns), config))) {
@@ -1596,6 +1598,7 @@ struct TestLegalizePatternDriver
15961598
// Analyze the convertible operations.
15971599
DenseSet<Operation *> legalizedOps;
15981600
ConversionConfig config;
1601+
config.foldingMode = foldingMode;
15991602
config.legalizableOps = &legalizedOps;
16001603
if (failed(applyAnalysisConversion(getOperation(), target,
16011604
std::move(patterns), config)))
@@ -1616,6 +1619,21 @@ struct TestLegalizePatternDriver
16161619
clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
16171620
clEnumValN(ConversionMode::Partial, "partial",
16181621
"Perform a partial conversion"))};
1622+
1623+
Option<DialectConversionFoldingMode> foldingMode{
1624+
*this, "test-legalize-folding-mode",
1625+
llvm::cl::desc("The folding mode to use with the test driver"),
1626+
llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
1627+
llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
1628+
"Never attempt to fold"),
1629+
clEnumValN(DialectConversionFoldingMode::BeforePatterns,
1630+
"before-patterns",
1631+
"Only attempt to fold not legal operations "
1632+
"before applying patterns"),
1633+
clEnumValN(DialectConversionFoldingMode::AfterPatterns,
1634+
"after-patterns",
1635+
"Only attempt to fold not legal operations "
1636+
"after applying patterns"))};
16191637
};
16201638
} // namespace
16211639

0 commit comments

Comments
 (0)