Skip to content

[mlir] Enable disabling folding in dialect conversion #152890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,16 @@ class PDLConversionConfig final {
// ConversionConfig
//===----------------------------------------------------------------------===//

/// An enum to control folding behavior during dialect conversion.
enum class DialectConversionFoldingMode {
/// Never attempt to fold.
Never,
/// Only attempt to fold not legal operations before applying patterns.
BeforePatterns,
/// Only attempt to fold not legal operations after applying patterns.
AfterPatterns,
};

/// Dialect conversion configuration.
struct ConversionConfig {
/// An optional callback used to notify about match failure diagnostics during
Expand Down Expand Up @@ -1240,6 +1250,10 @@ struct ConversionConfig {
/// your patterns do not trigger any IR rollbacks. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;

/// The folding mode to use during conversion.
DialectConversionFoldingMode foldingMode =
DialectConversionFoldingMode::BeforePatterns;
};

//===----------------------------------------------------------------------===//
Expand Down
31 changes: 22 additions & 9 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2197,15 +2197,16 @@ OperationLegalizer::legalize(Operation *op,
return success();
}

// If the operation isn't legal, try to fold it in-place.
// TODO: Should we always try to do this, even if the op is
// already legal?
if (succeeded(legalizeWithFold(op, rewriter))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
});
return success();
// If the operation is not legal, try to fold it in-place if the folding mode
// is 'BeforePatterns'. 'Never' will skip this.
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
if (succeeded(legalizeWithFold(op, rewriter))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
});
return success();
}
}

// Otherwise, we need to apply a legalization pattern to this operation.
Expand All @@ -2217,6 +2218,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}

// If the operation can't be legalized via patterns, try to fold it in-place
// if the folding mode is 'AfterPatterns'.
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
if (succeeded(legalizeWithFold(op, rewriter))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
});
return success();
}
}

LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Transforms/test-legalizer-fold-after.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s

// CHECK-LABEL: @fold_legalization
func.func @fold_legalization() -> i32 {
// CHECK-NOT: op_in_place_self_fold
// CHECK: 97
%1 = "test.op_in_place_self_fold"() : () -> (i32)
"test.return"(%1) : (i32) -> ()
}
9 changes: 9 additions & 0 deletions mlir/test/Transforms/test-legalizer-fold-before.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s

// CHECK-LABEL: @fold_legalization
func.func @fold_legalization() -> i32 {
// CHECK: op_in_place_self_fold
// CHECK-SAME: folded
%1 = "test.op_in_place_self_fold"() : () -> (i32)
"test.return"(%1) : (i32) -> ()
}
11 changes: 11 additions & 0 deletions mlir/test/Transforms/test-legalizer-no-fold.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s

// CHECK-LABEL: @remove_foldable_op(
func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
// Check that op was not folded.
// CHECK: "test.op_with_region_fold"
%0 = "test.op_with_region_fold"(%arg0) ({
"foo.op_with_region_terminator"() : () -> ()
}) : (i32) -> (i32)
"test.return"(%0) : (i32) -> ()
}
2 changes: 2 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let results = (outs I32);
let hasFolder = 1;
}
def : Pat<(TestOpInPlaceSelfFold:$op $_),
(TestOpConstant ConstantAttr<I32Attr, "97">)>;

// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
Expand Down
20 changes: 19 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
TerminatorOp, OneRegionOp>();
TerminatorOp, OneRegionOp, TestOpConstant>();
target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
Expand Down Expand Up @@ -1457,6 +1457,7 @@ struct TestLegalizePatternDriver
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
config.foldingMode = foldingMode;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
Expand All @@ -1476,6 +1477,7 @@ struct TestLegalizePatternDriver

ConversionConfig config;
DumpNotifications dumpNotifications;
config.foldingMode = foldingMode;
config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
std::move(patterns), config))) {
Expand All @@ -1490,6 +1492,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
ConversionConfig config;
config.foldingMode = foldingMode;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
Expand All @@ -1510,6 +1513,21 @@ struct TestLegalizePatternDriver
clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
clEnumValN(ConversionMode::Partial, "partial",
"Perform a partial conversion"))};

Option<DialectConversionFoldingMode> foldingMode{
*this, "test-legalize-folding-mode",
llvm::cl::desc("The folding mode to use with the test driver"),
llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
"Never attempt to fold"),
clEnumValN(DialectConversionFoldingMode::BeforePatterns,
"before-patterns",
"Only attempt to fold not legal operations "
"before applying patterns"),
clEnumValN(DialectConversionFoldingMode::AfterPatterns,
"after-patterns",
"Only attempt to fold not legal operations "
"after applying patterns"))};
};
} // namespace

Expand Down