Skip to content

Commit 8582025

Browse files
authored
[mlir][Transforms] Turn 1:N -> 1:1 dispatch fatal error into match failure (#153605)
Prior to this PR, the default behaviour of a conversion pattern which receives operands of a 1:N is to abort the compilation. This has historically been useful when the 1:N type conversion got merged into the dialect conversion as it allowed us to easily find patterns that should be capable of handling 1:N type conversions but didn't. However, this behaviour has the disadvantage of being non-composable: While the pattern in question cannot handle the 1:N type conversion, another pattern part of the set might, but doesn't get the chance as compilation is aborted. This PR fixes this behaviour by failing to match and instead of aborting, giving other patterns the chance to legalize an op. The implementation uses a reusable function called `dispatchTo1To1` to allow derived conversion patterns to also implement the behaviour.
1 parent 6b16a27 commit 8582025

File tree

6 files changed

+99
-23
lines changed

6 files changed

+99
-23
lines changed

flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,7 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
237237
virtual llvm::LogicalResult
238238
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
239239
mlir::ConversionPatternRewriter &rewriter) const {
240-
llvm::SmallVector<mlir::Value> oneToOneOperands =
241-
getOneToOneAdaptorOperands(adaptor.getOperands());
242-
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
240+
return dispatchTo1To1(*this, op, adaptor, rewriter);
243241
}
244242

245243
private:

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
243243
virtual LogicalResult
244244
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
245245
ConversionPatternRewriter &rewriter) const {
246-
SmallVector<Value> oneToOneOperands =
247-
getOneToOneAdaptorOperands(adaptor.getOperands());
248-
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
246+
return dispatchTo1To1(*this, op, adaptor, rewriter);
249247
}
250248

251249
private:
@@ -286,7 +284,7 @@ class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
286284
virtual LogicalResult
287285
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
288286
ConversionPatternRewriter &rewriter) const {
289-
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
287+
return dispatchTo1To1(*this, op, operands, rewriter);
290288
}
291289

292290
private:

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern {
521521

522522
/// Hook for derived classes to implement combined matching and rewriting.
523523
/// This overload supports only 1:1 replacements. The 1:N overload is called
524-
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
525-
/// error if 1:N replacements were found.
524+
/// by the driver. By default, it calls this 1:1 overload or fails to match
525+
/// if 1:N replacements were found.
526526
virtual LogicalResult
527527
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
528528
ConversionPatternRewriter &rewriter) const {
@@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern {
534534
virtual LogicalResult
535535
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
536536
ConversionPatternRewriter &rewriter) const {
537-
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
537+
return dispatchTo1To1(*this, op, operands, rewriter);
538538
}
539539

540540
/// Attempt to match and rewrite the IR root at the specified operation.
@@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern {
567567
/// try to extract the single value of each range to construct a the inputs
568568
/// for a 1:1 adaptor.
569569
///
570-
/// This function produces a fatal error if at least one range has 0 or
571-
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
572-
SmallVector<Value>
570+
/// Returns failure if at least one range has 0 or more than 1 value.
571+
FailureOr<SmallVector<Value>>
573572
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
574573

574+
/// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
575+
/// if possible and emit diagnostic with a failure return value otherwise.
576+
/// 'self' should be '*this' of the derived-pattern and is used to dispatch
577+
/// to the correct 'matchAndRewrite' method in the derived pattern.
578+
template <typename SelfPattern, typename SourceOp>
579+
static LogicalResult dispatchTo1To1(const SelfPattern &self, SourceOp op,
580+
ArrayRef<ValueRange> operands,
581+
ConversionPatternRewriter &rewriter);
582+
583+
/// Same as above, but accepts an adaptor as operand.
584+
template <typename SelfPattern, typename SourceOp>
585+
static LogicalResult dispatchTo1To1(
586+
const SelfPattern &self, SourceOp op,
587+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
588+
ConversionPatternRewriter &rewriter);
589+
575590
protected:
576591
/// An optional type converter for use by this pattern.
577592
const TypeConverter *typeConverter = nullptr;
@@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern {
620635
virtual LogicalResult
621636
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
622637
ConversionPatternRewriter &rewriter) const {
623-
SmallVector<Value> oneToOneOperands =
624-
getOneToOneAdaptorOperands(adaptor.getOperands());
625-
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
638+
return dispatchTo1To1(*this, op, adaptor, rewriter);
626639
}
627640

628641
private:
@@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
666679
virtual LogicalResult
667680
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
668681
ConversionPatternRewriter &rewriter) const {
669-
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
682+
return dispatchTo1To1(*this, op, operands, rewriter);
670683
}
671684

672685
private:
@@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter {
865878
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
866879
};
867880

881+
template <typename SelfPattern, typename SourceOp>
882+
LogicalResult
883+
ConversionPattern::dispatchTo1To1(const SelfPattern &self, SourceOp op,
884+
ArrayRef<ValueRange> operands,
885+
ConversionPatternRewriter &rewriter) {
886+
FailureOr<SmallVector<Value>> oneToOneOperands =
887+
self.getOneToOneAdaptorOperands(operands);
888+
if (failed(oneToOneOperands))
889+
return rewriter.notifyMatchFailure(op,
890+
"pattern '" + self.getDebugName() +
891+
"' does not support 1:N conversion");
892+
return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
893+
}
894+
895+
template <typename SelfPattern, typename SourceOp>
896+
LogicalResult ConversionPattern::dispatchTo1To1(
897+
const SelfPattern &self, SourceOp op,
898+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
899+
ConversionPatternRewriter &rewriter) {
900+
FailureOr<SmallVector<Value>> oneToOneOperands =
901+
self.getOneToOneAdaptorOperands(adaptor.getOperands());
902+
if (failed(oneToOneOperands))
903+
return rewriter.notifyMatchFailure(op,
904+
"pattern '" + self.getDebugName() +
905+
"' does not support 1:N conversion");
906+
return self.matchAndRewrite(
907+
op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
908+
}
909+
868910
//===----------------------------------------------------------------------===//
869911
// ConversionTarget
870912
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,17 +2244,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
22442244
// ConversionPattern
22452245
//===----------------------------------------------------------------------===//
22462246

2247-
SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
2247+
FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
22482248
ArrayRef<ValueRange> operands) const {
22492249
SmallVector<Value> oneToOneOperands;
22502250
oneToOneOperands.reserve(operands.size());
22512251
for (ValueRange operand : operands) {
22522252
if (operand.size() != 1)
2253-
llvm::report_fatal_error("pattern '" + getDebugName() +
2254-
"' does not support 1:N conversion");
2253+
return failure();
2254+
22552255
oneToOneOperands.push_back(operand.front());
22562256
}
2257-
return oneToOneOperands;
2257+
return std::move(oneToOneOperands);
22582258
}
22592259

22602260
LogicalResult

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,24 @@ func.func @test_lookup_without_converter() {
439439
// expected-remark@+1 {{op 'func.return' is not legalizable}}
440440
return
441441
}
442+
443+
// -----
444+
// expected-remark@-1 {{applyPartialConversion failed}}
445+
446+
func.func @test_skip_1to1_pattern(%arg0: f32) {
447+
// expected-error@+1 {{failed to legalize operation 'test.type_consumer'}}
448+
"test.type_consumer"(%arg0) : (f32) -> ()
449+
return
450+
}
451+
452+
// -----
453+
454+
// Demonstrate that the pattern generally works, but only for 1:1 type
455+
// conversions.
456+
457+
// CHECK-LABEL: @test_working_1to1_pattern(
458+
func.func @test_working_1to1_pattern(%arg0: f16) {
459+
// CHECK-NEXT: "test.return"() : () -> ()
460+
"test.type_consumer"(%arg0) : (f16) -> ()
461+
"test.return"() : () -> ()
462+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,23 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
13861386
}
13871387
};
13881388

1389+
/// Pattern that erases 'test.type_consumers' iff the input operand is the
1390+
/// result of a 1:1 type conversion.
1391+
/// Used to test correct skipping of 1:1 patterns in the 1:N case.
1392+
class TestTypeConsumerOpPattern
1393+
: public OpConversionPattern<TestTypeConsumerOp> {
1394+
public:
1395+
TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter)
1396+
: OpConversionPattern<TestTypeConsumerOp>(converter, ctx) {}
1397+
1398+
LogicalResult
1399+
matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands,
1400+
ConversionPatternRewriter &rewriter) const final {
1401+
rewriter.eraseOp(op);
1402+
return success();
1403+
}
1404+
};
1405+
13891406
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
13901407
/// function is just to trigger compiler errors. It is never executed.
13911408
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1497,8 +1514,8 @@ struct TestLegalizePatternDriver
14971514
TestRepetitive1ToNConsumer>(&getContext());
14981515
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
14991516
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
1500-
TestBlockArgReplace, TestReplaceWithValidConsumer>(
1501-
&getContext(), converter);
1517+
TestBlockArgReplace, TestReplaceWithValidConsumer,
1518+
TestTypeConsumerOpPattern>(&getContext(), converter);
15021519
patterns.add<TestConvertBlockArgs>(converter, &getContext());
15031520
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
15041521
converter);

0 commit comments

Comments
 (0)