@@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern {
538538 ConversionPatternRewriter &rewriter) const {
539539 llvm_unreachable (" unimplemented rewrite" );
540540 }
541+ virtual void rewrite (Operation *op, ArrayRef<ValueRange> operands,
542+ ConversionPatternRewriter &rewriter) const {
543+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
544+ }
541545
542546 // / Hook for derived classes to implement combined matching and rewriting.
547+ // / This overload supports only 1:1 replacements. The 1:N overload is called
548+ // / by the driver. By default, it calls this 1:1 overload or reports a fatal
549+ // / error if 1:N replacements were found.
543550 virtual LogicalResult
544551 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
545552 ConversionPatternRewriter &rewriter) const {
@@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern {
549556 return success ();
550557 }
551558
559+ // / Hook for derived classes to implement combined matching and rewriting.
560+ // / This overload supports 1:N replacements.
561+ virtual LogicalResult
562+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
563+ ConversionPatternRewriter &rewriter) const {
564+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
565+ }
566+
552567 // / Attempt to match and rewrite the IR root at the specified operation.
553568 LogicalResult matchAndRewrite (Operation *op,
554569 PatternRewriter &rewriter) const final ;
@@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern {
575590 : RewritePattern(std::forward<Args>(args)...),
576591 typeConverter (&typeConverter) {}
577592
593+ // / Given an array of value ranges, which are the inputs to a 1:N adaptor,
594+ // / try to extract the single value of each range to construct a the inputs
595+ // / for a 1:1 adaptor.
596+ // /
597+ // / This function produces a fatal error if at least one range has 0 or
598+ // / more than 1 value: "pattern 'name' does not support 1:N conversion"
599+ SmallVector<Value>
600+ getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
601+
578602protected:
579603 // / An optional type converter for use by this pattern.
580604 const TypeConverter *typeConverter = nullptr ;
@@ -590,6 +614,8 @@ template <typename SourceOp>
590614class OpConversionPattern : public ConversionPattern {
591615public:
592616 using OpAdaptor = typename SourceOp::Adaptor;
617+ using OneToNOpAdaptor =
618+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
593619
594620 OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
595621 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern {
608634 auto sourceOp = cast<SourceOp>(op);
609635 rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
610636 }
637+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
638+ ConversionPatternRewriter &rewriter) const final {
639+ auto sourceOp = cast<SourceOp>(op);
640+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
641+ }
611642 LogicalResult
612643 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
613644 ConversionPatternRewriter &rewriter) const final {
614645 auto sourceOp = cast<SourceOp>(op);
615646 return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
616647 }
648+ LogicalResult
649+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
650+ ConversionPatternRewriter &rewriter) const final {
651+ auto sourceOp = cast<SourceOp>(op);
652+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
653+ rewriter);
654+ }
617655
618656 // / Rewrite and Match methods that operate on the SourceOp type. These must be
619657 // / overridden by the derived pattern class.
@@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern {
624662 ConversionPatternRewriter &rewriter) const {
625663 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
626664 }
665+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
666+ ConversionPatternRewriter &rewriter) const {
667+ SmallVector<Value> oneToOneOperands =
668+ getOneToOneAdaptorOperands (adaptor.getOperands ());
669+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
670+ }
627671 virtual LogicalResult
628672 matchAndRewrite (SourceOp op, OpAdaptor adaptor,
629673 ConversionPatternRewriter &rewriter) const {
@@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern {
632676 rewrite (op, adaptor, rewriter);
633677 return success ();
634678 }
679+ virtual LogicalResult
680+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
681+ ConversionPatternRewriter &rewriter) const {
682+ SmallVector<Value> oneToOneOperands =
683+ getOneToOneAdaptorOperands (adaptor.getOperands ());
684+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
685+ }
635686
636687private:
637688 using ConversionPattern::matchAndRewrite;
@@ -657,18 +708,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
657708 ConversionPatternRewriter &rewriter) const final {
658709 rewrite (cast<SourceOp>(op), operands, rewriter);
659710 }
711+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
712+ ConversionPatternRewriter &rewriter) const final {
713+ rewrite (cast<SourceOp>(op), operands, rewriter);
714+ }
660715 LogicalResult
661716 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
662717 ConversionPatternRewriter &rewriter) const final {
663718 return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
664719 }
720+ LogicalResult
721+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
722+ ConversionPatternRewriter &rewriter) const final {
723+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
724+ }
665725
666726 // / Rewrite and Match methods that operate on the SourceOp type. These must be
667727 // / overridden by the derived pattern class.
668728 virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
669729 ConversionPatternRewriter &rewriter) const {
670730 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
671731 }
732+ virtual void rewrite (SourceOp op, ArrayRef<ValueRange> operands,
733+ ConversionPatternRewriter &rewriter) const {
734+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
735+ }
672736 virtual LogicalResult
673737 matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
674738 ConversionPatternRewriter &rewriter) const {
@@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
677741 rewrite (op, operands, rewriter);
678742 return success ();
679743 }
744+ virtual LogicalResult
745+ matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
746+ ConversionPatternRewriter &rewriter) const {
747+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
748+ }
680749
681750private:
682751 using ConversionPattern::matchAndRewrite;
0 commit comments