Skip to content

Commit 63eac17

Browse files
[mlir] Dialect Conversion: Add support for post-order legalization order
1 parent 70ff2c9 commit 63eac17

File tree

5 files changed

+133
-4
lines changed

5 files changed

+133
-4
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,20 @@ class ConversionPatternRewriter final : public PatternRewriter {
981981
/// Return a reference to the internal implementation.
982982
detail::ConversionPatternRewriterImpl &getImpl();
983983

984+
/// Attempt to legalize the given operation. This can be used within
985+
/// conversion patterns to change the default pre-order legalization order.
986+
/// Returns "success" if the operation was legalized, "failure" otherwise.
987+
LogicalResult legalize(Operation *op);
988+
989+
/// Attempt to legalize the given region. This can be used within
990+
/// conversion patterns to change the default pre-order legalization order.
991+
/// Returns "success" if the region was legalized, "failure" otherwise.
992+
///
993+
/// If the current pattern runs with a type converter, the entry block
994+
/// signature will be converted before legalizing the operations in the
995+
/// region.
996+
LogicalResult legalize(Region *r);
997+
984998
private:
985999
// Allow OperationConverter to construct new rewriters.
9861000
friend struct OperationConverter;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -862,8 +862,11 @@ static bool hasRewrite(R &&rewrites, Block *block) {
862862
//===----------------------------------------------------------------------===//
863863
// ConversionPatternRewriterImpl
864864
//===----------------------------------------------------------------------===//
865+
865866
namespace mlir {
866867
namespace detail {
868+
class OperationLegalizer;
869+
867870
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
868871
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
869872
const ConversionConfig &config)
@@ -915,6 +918,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
915918
/// Return "true" if the given operation was replaced or erased.
916919
bool wasOpReplaced(Operation *op) const;
917920

921+
/// Set the operation legalizer to use for recursive legalization.
922+
void setOperationLegalizer(OperationLegalizer *legalizer) {
923+
opLegalizer = legalizer;
924+
}
925+
918926
/// Lookup the most recently mapped values with the desired types in the
919927
/// mapping, taking into account only replacements. Perform a best-effort
920928
/// search for existing materializations with the desired types.
@@ -1121,6 +1129,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11211129
/// converting the arguments of blocks within that region.
11221130
DenseMap<Region *, const TypeConverter *> regionToConverter;
11231131

1132+
/// The operation legalizer to use for recursive legalization. This is set by
1133+
/// the OperationConverter when the rewriter is created.
1134+
OperationLegalizer *opLegalizer = nullptr;
1135+
11241136
/// Dialect conversion configuration.
11251137
const ConversionConfig &config;
11261138

@@ -2357,7 +2369,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
23572369
// OperationLegalizer
23582370
//===----------------------------------------------------------------------===//
23592371

2360-
namespace {
2372+
namespace mlir::detail {
23612373
/// A set of rewrite patterns that can be used to legalize a given operation.
23622374
using LegalizationPatterns = SmallVector<const Pattern *, 1>;
23632375

@@ -2454,7 +2466,7 @@ class OperationLegalizer {
24542466
/// The pattern applicator to use for conversions.
24552467
PatternApplicator applicator;
24562468
};
2457-
} // namespace
2469+
} // namespace mlir::detail
24582470

24592471
OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
24602472
const ConversionTarget &targetInfo,
@@ -2854,6 +2866,41 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates(
28542866
return success();
28552867
}
28562868

2869+
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
2870+
return impl->opLegalizer->legalize(op);
2871+
}
2872+
2873+
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2874+
// Fast path: If the region is empty, there is nothing to legalize.
2875+
if (r->empty())
2876+
return success();
2877+
2878+
// Gather a list of all operations to legalize. This is done before
2879+
// converting the entry block signature because unrealized_conversion_cast
2880+
// ops should not be included.
2881+
SmallVector<Operation *> ops;
2882+
for (Block &b : *r)
2883+
for (Operation &op : b)
2884+
ops.push_back(&op);
2885+
2886+
// If the current pattern runs with a type converter, convert the entry block
2887+
// signature.
2888+
if (const TypeConverter *converter = impl->currentTypeConverter) {
2889+
std::optional<TypeConverter::SignatureConversion> conversion =
2890+
converter->convertBlockSignature(&r->front());
2891+
if (!conversion)
2892+
return failure();
2893+
applySignatureConversion(&r->front(), *conversion, converter);
2894+
}
2895+
2896+
// Legalize all operations in the region.
2897+
for (Operation *op : ops)
2898+
if (failed(legalize(op)))
2899+
return failure();
2900+
2901+
return success();
2902+
}
2903+
28572904
//===----------------------------------------------------------------------===//
28582905
// Cost Model
28592906
//===----------------------------------------------------------------------===//
@@ -3218,7 +3265,10 @@ struct OperationConverter {
32183265
const ConversionConfig &config,
32193266
OpConversionMode mode)
32203267
: rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
3221-
mode(mode) {}
3268+
mode(mode) {
3269+
// Set the legalizer in the rewriter so patterns can recursively legalize.
3270+
rewriter.getImpl().setOperationLegalizer(&opLegalizer);
3271+
}
32223272

32233273
/// Converts the given operations to the conversion target.
32243274
LogicalResult convertOperations(ArrayRef<Operation *> ops);

mlir/test/Transforms/test-legalizer-rollback.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
163163
"test.return"(%0) : (i32) -> ()
164164
}
165165
}
166+
167+
// -----
168+
169+
// CHECK-LABEL: func @test_failed_preorder_legalization
170+
// CHECK: "test.post_order_legalization"() ({
171+
// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
172+
// CHECK: "test.return"(%[[r]]) : (i32) -> ()
173+
// CHECK: }) : () -> ()
174+
// expected-remark @+1 {{applyPartialConversion failed}}
175+
module {
176+
func.func @test_failed_preorder_legalization() {
177+
// expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
178+
"test.post_order_legalization"() ({
179+
%0 = "test.illegal_op_g"() : () -> (i32)
180+
"test.return"(%0) : (i32) -> ()
181+
}) : () -> ()
182+
return
183+
}
184+
}

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,29 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
448448
"test.type_consumer"(%arg0) : (f16) -> ()
449449
"test.return"() : () -> ()
450450
}
451+
452+
// -----
453+
454+
// The region of "test.post_order_legalization" is converted before the op.
455+
456+
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
457+
// CHECK: notifyOperationInserted: test.invalid
458+
// CHECK: notifyBlockErased
459+
// CHECK: notifyOperationInserted: test.valid, was unlinked
460+
// CHECK: notifyOperationReplaced: test.invalid
461+
// CHECK: notifyOperationErased: test.invalid
462+
// CHECK: notifyOperationModified: test.post_order_legalization
463+
464+
// CHECK-LABEL: func @test_preorder_legalization
465+
// CHECK: "test.post_order_legalization"() ({
466+
// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
467+
// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
468+
// CHECK: }) {is_legal} : () -> ()
469+
func.func @test_preorder_legalization() {
470+
"test.post_order_legalization"() ({
471+
^bb0(%arg0: i64):
472+
"test.invalid"(%arg0) : (i64) -> ()
473+
}) : () -> ()
474+
// expected-remark @+1 {{is not legalizable}}
475+
return
476+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
14181418
}
14191419
};
14201420

1421+
class TestPostOrderLegalization : public ConversionPattern {
1422+
public:
1423+
TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
1424+
: ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
1425+
LogicalResult
1426+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1427+
ConversionPatternRewriter &rewriter) const final {
1428+
for (Region &r : op->getRegions())
1429+
if (failed(rewriter.legalize(&r)))
1430+
return failure();
1431+
rewriter.modifyOpInPlace(
1432+
op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
1433+
return success();
1434+
}
1435+
};
1436+
14211437
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
14221438
/// function is just to trigger compiler errors. It is never executed.
14231439
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
15321548
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
15331549
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
15341550
TestValueReplace, TestReplaceWithValidConsumer,
1535-
TestTypeConsumerOpPattern>(&getContext(), converter);
1551+
TestTypeConsumerOpPattern, TestPostOrderLegalization>(
1552+
&getContext(), converter);
15361553
patterns.add<TestConvertBlockArgs>(converter, &getContext());
15371554
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
15381555
converter);
@@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
15601577
target.addDynamicallyLegalOp(
15611578
OperationName("test.value_replace", &getContext()),
15621579
[](Operation *op) { return op->hasAttr("is_legal"); });
1580+
target.addDynamicallyLegalOp(
1581+
OperationName("test.post_order_legalization", &getContext()),
1582+
[](Operation *op) { return op->hasAttr("is_legal"); });
15631583

15641584
// TestCreateUnregisteredOp creates `arith.constant` operation,
15651585
// which was not added to target intentionally to test

0 commit comments

Comments
 (0)