@@ -537,8 +537,15 @@ class ConversionPattern : public RewritePattern {
537537 ConversionPatternRewriter &rewriter) const {
538538 llvm_unreachable (" unimplemented rewrite" );
539539 }
540+ virtual void rewrite (Operation *op, ArrayRef<ValueRange> operands,
541+ ConversionPatternRewriter &rewriter) const {
542+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
543+ }
540544
541545 // / Hook for derived classes to implement combined matching and rewriting.
546+ // / This overload supports only 1:1 replacements. The 1:N overload is called
547+ // / by the driver. By default, it calls this 1:1 overload or reports a fatal
548+ // / error if 1:N replacements were found.
542549 virtual LogicalResult
543550 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
544551 ConversionPatternRewriter &rewriter) const {
@@ -548,6 +555,14 @@ class ConversionPattern : public RewritePattern {
548555 return success ();
549556 }
550557
558+ // / Hook for derived classes to implement combined matching and rewriting.
559+ // / This overload supports 1:N replacements.
560+ virtual LogicalResult
561+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
562+ ConversionPatternRewriter &rewriter) const {
563+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
564+ }
565+
551566 // / Attempt to match and rewrite the IR root at the specified operation.
552567 LogicalResult matchAndRewrite (Operation *op,
553568 PatternRewriter &rewriter) const final ;
@@ -574,6 +589,15 @@ class ConversionPattern : public RewritePattern {
574589 : RewritePattern(std::forward<Args>(args)...),
575590 typeConverter (&typeConverter) {}
576591
592+ // / Given an array of value ranges, which are the inputs to a 1:N adaptor,
593+ // / try to extract the single value of each range to construct a the inputs
594+ // / for a 1:1 adaptor.
595+ // /
596+ // / This function produces a fatal error if at least one range has 0 or
597+ // / more than 1 value: "pattern 'name' does not support 1:N conversion"
598+ SmallVector<Value>
599+ getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
600+
577601protected:
578602 // / An optional type converter for use by this pattern.
579603 const TypeConverter *typeConverter = nullptr ;
@@ -589,6 +613,8 @@ template <typename SourceOp>
589613class OpConversionPattern : public ConversionPattern {
590614public:
591615 using OpAdaptor = typename SourceOp::Adaptor;
616+ using OneToNOpAdaptor =
617+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
592618
593619 OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
594620 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +633,24 @@ class OpConversionPattern : public ConversionPattern {
607633 auto sourceOp = cast<SourceOp>(op);
608634 rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
609635 }
636+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
637+ ConversionPatternRewriter &rewriter) const final {
638+ auto sourceOp = cast<SourceOp>(op);
639+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
640+ }
610641 LogicalResult
611642 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
612643 ConversionPatternRewriter &rewriter) const final {
613644 auto sourceOp = cast<SourceOp>(op);
614645 return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
615646 }
647+ LogicalResult
648+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
649+ ConversionPatternRewriter &rewriter) const final {
650+ auto sourceOp = cast<SourceOp>(op);
651+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
652+ rewriter);
653+ }
616654
617655 // / Rewrite and Match methods that operate on the SourceOp type. These must be
618656 // / overridden by the derived pattern class.
@@ -623,6 +661,12 @@ class OpConversionPattern : public ConversionPattern {
623661 ConversionPatternRewriter &rewriter) const {
624662 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
625663 }
664+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
665+ ConversionPatternRewriter &rewriter) const {
666+ SmallVector<Value> oneToOneOperands =
667+ getOneToOneAdaptorOperands (adaptor.getOperands ());
668+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
669+ }
626670 virtual LogicalResult
627671 matchAndRewrite (SourceOp op, OpAdaptor adaptor,
628672 ConversionPatternRewriter &rewriter) const {
@@ -631,6 +675,13 @@ class OpConversionPattern : public ConversionPattern {
631675 rewrite (op, adaptor, rewriter);
632676 return success ();
633677 }
678+ virtual LogicalResult
679+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
680+ ConversionPatternRewriter &rewriter) const {
681+ SmallVector<Value> oneToOneOperands =
682+ getOneToOneAdaptorOperands (adaptor.getOperands ());
683+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
684+ }
634685
635686private:
636687 using ConversionPattern::matchAndRewrite;
@@ -656,18 +707,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656707 ConversionPatternRewriter &rewriter) const final {
657708 rewrite (cast<SourceOp>(op), operands, rewriter);
658709 }
710+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
711+ ConversionPatternRewriter &rewriter) const final {
712+ rewrite (cast<SourceOp>(op), operands, rewriter);
713+ }
659714 LogicalResult
660715 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
661716 ConversionPatternRewriter &rewriter) const final {
662717 return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
663718 }
719+ LogicalResult
720+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
721+ ConversionPatternRewriter &rewriter) const final {
722+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
723+ }
664724
665725 // / Rewrite and Match methods that operate on the SourceOp type. These must be
666726 // / overridden by the derived pattern class.
667727 virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
668728 ConversionPatternRewriter &rewriter) const {
669729 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
670730 }
731+ virtual void rewrite (SourceOp op, ArrayRef<ValueRange> operands,
732+ ConversionPatternRewriter &rewriter) const {
733+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
734+ }
671735 virtual LogicalResult
672736 matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
673737 ConversionPatternRewriter &rewriter) const {
@@ -676,6 +740,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676740 rewrite (op, operands, rewriter);
677741 return success ();
678742 }
743+ virtual LogicalResult
744+ matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
745+ ConversionPatternRewriter &rewriter) const {
746+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
747+ }
679748
680749private:
681750 using ConversionPattern::matchAndRewrite;
0 commit comments