@@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern {
521
521
522
522
// / Hook for derived classes to implement combined matching and rewriting.
523
523
// / This overload supports only 1:1 replacements. The 1:N overload is called
524
- // / by the driver. By default, it calls this 1:1 overload or reports a fatal
525
- // / error if 1:N replacements were found.
524
+ // / by the driver. By default, it calls this 1:1 overload or fails to match
525
+ // / if 1:N replacements were found.
526
526
virtual LogicalResult
527
527
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
528
528
ConversionPatternRewriter &rewriter) const {
@@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern {
534
534
virtual LogicalResult
535
535
matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
536
536
ConversionPatternRewriter &rewriter) const {
537
- return matchAndRewrite ( op, getOneToOneAdaptorOperands ( operands) , rewriter);
537
+ return dispatchTo1To1 (* this , op, operands, rewriter);
538
538
}
539
539
540
540
// / Attempt to match and rewrite the IR root at the specified operation.
@@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern {
567
567
// / try to extract the single value of each range to construct a the inputs
568
568
// / for a 1:1 adaptor.
569
569
// /
570
- // / This function produces a fatal error if at least one range has 0 or
571
- // / more than 1 value: "pattern 'name' does not support 1:N conversion"
572
- SmallVector<Value>
570
+ // / Returns failure if at least one range has 0 or more than 1 value.
571
+ FailureOr<SmallVector<Value>>
573
572
getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
574
573
574
+ // / Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
575
+ // / if possible and emit diagnostic with a failure return value otherwise.
576
+ // / 'self' should be '*this' of the derived-pattern and is used to dispatch
577
+ // / to the correct 'matchAndRewrite' method in the derived pattern.
578
+ template <typename SelfPattern, typename SourceOp>
579
+ static LogicalResult dispatchTo1To1 (const SelfPattern &self, SourceOp op,
580
+ ArrayRef<ValueRange> operands,
581
+ ConversionPatternRewriter &rewriter);
582
+
583
+ // / Same as above, but accepts an adaptor as operand.
584
+ template <typename SelfPattern, typename SourceOp>
585
+ static LogicalResult dispatchTo1To1 (
586
+ const SelfPattern &self, SourceOp op,
587
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
588
+ ConversionPatternRewriter &rewriter);
589
+
575
590
protected:
576
591
// / An optional type converter for use by this pattern.
577
592
const TypeConverter *typeConverter = nullptr ;
@@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern {
620
635
virtual LogicalResult
621
636
matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
622
637
ConversionPatternRewriter &rewriter) const {
623
- SmallVector<Value> oneToOneOperands =
624
- getOneToOneAdaptorOperands (adaptor.getOperands ());
625
- return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
638
+ return dispatchTo1To1 (*this , op, adaptor, rewriter);
626
639
}
627
640
628
641
private:
@@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
666
679
virtual LogicalResult
667
680
matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
668
681
ConversionPatternRewriter &rewriter) const {
669
- return matchAndRewrite ( op, getOneToOneAdaptorOperands ( operands) , rewriter);
682
+ return dispatchTo1To1 (* this , op, operands, rewriter);
670
683
}
671
684
672
685
private:
@@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter {
865
878
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
866
879
};
867
880
881
+ template <typename SelfPattern, typename SourceOp>
882
+ LogicalResult
883
+ ConversionPattern::dispatchTo1To1 (const SelfPattern &self, SourceOp op,
884
+ ArrayRef<ValueRange> operands,
885
+ ConversionPatternRewriter &rewriter) {
886
+ FailureOr<SmallVector<Value>> oneToOneOperands =
887
+ self.getOneToOneAdaptorOperands (operands);
888
+ if (failed (oneToOneOperands))
889
+ return rewriter.notifyMatchFailure (op,
890
+ " pattern '" + self.getDebugName () +
891
+ " ' does not support 1:N conversion" );
892
+ return self.matchAndRewrite (op, *oneToOneOperands, rewriter);
893
+ }
894
+
895
+ template <typename SelfPattern, typename SourceOp>
896
+ LogicalResult ConversionPattern::dispatchTo1To1 (
897
+ const SelfPattern &self, SourceOp op,
898
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
899
+ ConversionPatternRewriter &rewriter) {
900
+ FailureOr<SmallVector<Value>> oneToOneOperands =
901
+ self.getOneToOneAdaptorOperands (adaptor.getOperands ());
902
+ if (failed (oneToOneOperands))
903
+ return rewriter.notifyMatchFailure (op,
904
+ " pattern '" + self.getDebugName () +
905
+ " ' does not support 1:N conversion" );
906
+ return self.matchAndRewrite (
907
+ op, typename SourceOp::Adaptor (*oneToOneOperands, adaptor), rewriter);
908
+ }
909
+
868
910
// ===----------------------------------------------------------------------===//
869
911
// ConversionTarget
870
912
// ===----------------------------------------------------------------------===//
0 commit comments