Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ class TypeConverter {
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
template <typename TargetType> TargetType convertType(Type t) const {
template <typename TargetType>
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}

Expand Down Expand Up @@ -1118,6 +1119,14 @@ struct ConversionConfig {
// already been modified) and iterators into past IR state cannot be
// represented at the moment.
RewriterBase::Listener *listener = nullptr;

/// If set to "true", the dialect conversion driver attempts to fold
/// operations throughout the conversion. This is problematic because op
/// folders may assume that the IR is in a valid state at the beginning of
/// the folding process. However, the dialect conversion does not guarantee
/// that because some IR modifications are delayed until the end of the
/// conversion.
bool foldOps = true;
Comment on lines +1124 to +1129
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether there should be either a TODO such as "change to false in the future" (if we want to take that route) or whether the comment should note that it is true for historic reasons.

Looks funny that the majority of the paragraph discourages using the options but we default to it being true

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would go further: if we consider this unsafe, we should just deprecate this mode entirely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do mean setting this to false by default or adding a comment that this is deprecated (or both)? We have at least one test case in the test dialect that tests the folding.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we should be setting this to false by default, folks who are broken can set it back to true, but we also document it as deprecated.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are broken with foldOps = false:

Failed Tests (8):
  MLIR :: Conversion/AffineToStandard/lower-affine.mlir
  MLIR :: Conversion/FuncToLLVM/calling-convention.mlir
  MLIR :: Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
  MLIR :: Conversion/ShapeToStandard/shape-to-standard.mlir
  MLIR :: Conversion/VectorToLLVM/vector-to-llvm.mlir
  MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
  MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
  MLIR :: Dialect/Vector/linearize.mlir

They are mostly FileCheck failures, but vector-to-llvm.mlir is actually broken. I'm busy with other stuff right now, so it might take a while. (Or if someone else wants to take this over, feel free to.)

};

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2030,7 +2030,7 @@ OperationLegalizer::legalize(Operation *op,
// If the operation isn't legal, try to fold it in-place.
// TODO: Should we always try to do this, even if the op is
// already legal?
if (succeeded(legalizeWithFold(op, rewriter))) {
if (config.foldOps && succeeded(legalizeWithFold(op, rewriter))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Transforms/test-legalizer-analysis.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s
// expected-remark@-2 {{op 'builtin.module' is legalizable}}

// expected-remark@+1 {{op 'func.func' is legalizable}}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Transforms/test-legalizer-full.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: func @multi_level_mapping
func.func @multi_level_mapping() {
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Transforms/test-legalizer-no-fold.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="fold-ops=0" %s | FileCheck %s

// CHECK-LABEL: @remove_foldable_op(
func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
// Check that op was not folded.
// CHECK: "test.op_with_region_fold"
%0 = "test.op_with_region_fold"(%arg0) ({
"foo.op_with_region_terminator"() : () -> ()
}) : (i32) -> (i32)
"test.return"(%0) : (i32) -> ()
}
42 changes: 23 additions & 19 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,9 @@ struct TestLegalizePatternDriver
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Full, Partial };

TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
TestLegalizePatternDriver() = default;
TestLegalizePatternDriver(const TestLegalizePatternDriver &other)
: PassWrapper(other) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, test::TestDialect>();
Expand Down Expand Up @@ -1179,6 +1181,7 @@ struct TestLegalizePatternDriver
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
config.foldOps = foldOps;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
Expand All @@ -1197,6 +1200,7 @@ struct TestLegalizePatternDriver
});

ConversionConfig config;
config.foldOps = foldOps;
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
Expand All @@ -1212,6 +1216,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
ConversionConfig config;
config.foldOps = foldOps;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
Expand All @@ -1222,24 +1227,25 @@ struct TestLegalizePatternDriver
op->emitRemark() << "op '" << op->getName() << "' is legalizable";
}

/// The mode of conversion to use.
ConversionMode mode;
Option<bool> foldOps{
*this, "fold-ops",
llvm::cl::desc("Fold ops throughout the conversion process"),
llvm::cl::init(true)};

Option<TestLegalizePatternDriver::ConversionMode> mode{
*this, "legalize-mode",
llvm::cl::desc("The legalization mode to use with the test driver"),
llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
llvm::cl::values(
clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
"analysis", "Perform an analysis conversion"),
clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
"Perform a full conversion"),
clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
"partial", "Perform a partial conversion"))};
};
} // namespace

static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
legalizerConversionMode(
"test-legalize-mode",
llvm::cl::desc("The legalization mode to use with the test driver"),
llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
llvm::cl::values(
clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
"analysis", "Perform an analysis conversion"),
clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
"Perform a full conversion"),
clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
"partial", "Perform a partial conversion")));

//===----------------------------------------------------------------------===//
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of an original value that was replaced using
Expand Down Expand Up @@ -1909,9 +1915,7 @@ void registerPatternsTestPass() {
PassRegistration<TestPatternDriver>();
PassRegistration<TestStrictPatternDriver>();

PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
});
PassRegistration<TestLegalizePatternDriver>();

PassRegistration<TestRemappedValue>();

Expand Down