@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537537 ConversionPatternRewriter &rewriter) const {
538538 llvm_unreachable (" unimplemented rewrite" );
539539 }
540+ virtual void rewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
541+ ConversionPatternRewriter &rewriter) const {
542+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
543+ }
540544
541545 // / Hook for derived classes to implement combined matching and rewriting.
542546 virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547551 rewrite (op, operands, rewriter);
548552 return success ();
549553 }
554+ virtual LogicalResult
555+ matchAndRewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
556+ ConversionPatternRewriter &rewriter) const {
557+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
558+ }
550559
551560 // / Attempt to match and rewrite the IR root at the specified operation.
552561 LogicalResult matchAndRewrite (Operation *op,
@@ -574,6 +583,9 @@ class ConversionPattern : public RewritePattern {
574583 : RewritePattern(std::forward<Args>(args)...),
575584 typeConverter (&typeConverter) {}
576585
586+ SmallVector<Value>
587+ getOneToOneAdaptorOperands (ArrayRef<ArrayRef<Value>> operands) const ;
588+
577589protected:
578590 // / An optional type converter for use by this pattern.
579591 const TypeConverter *typeConverter = nullptr ;
@@ -589,6 +601,8 @@ template <typename SourceOp>
589601class OpConversionPattern : public ConversionPattern {
590602public:
591603 using OpAdaptor = typename SourceOp::Adaptor;
604+ using OneToNOpAdaptor =
605+ typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
592606
593607 OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
594608 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +621,24 @@ class OpConversionPattern : public ConversionPattern {
607621 auto sourceOp = cast<SourceOp>(op);
608622 rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
609623 }
624+ void rewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
625+ ConversionPatternRewriter &rewriter) const final {
626+ auto sourceOp = cast<SourceOp>(op);
627+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
628+ }
610629 LogicalResult
611630 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
612631 ConversionPatternRewriter &rewriter) const final {
613632 auto sourceOp = cast<SourceOp>(op);
614633 return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
615634 }
635+ LogicalResult
636+ matchAndRewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
637+ ConversionPatternRewriter &rewriter) const final {
638+ auto sourceOp = cast<SourceOp>(op);
639+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
640+ rewriter);
641+ }
616642
617643 // / Rewrite and Match methods that operate on the SourceOp type. These must be
618644 // / overridden by the derived pattern class.
@@ -623,6 +649,12 @@ class OpConversionPattern : public ConversionPattern {
623649 ConversionPatternRewriter &rewriter) const {
624650 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
625651 }
652+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
653+ ConversionPatternRewriter &rewriter) const {
654+ SmallVector<Value> oneToOneOperands =
655+ getOneToOneAdaptorOperands (adaptor.getOperands ());
656+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
657+ }
626658 virtual LogicalResult
627659 matchAndRewrite (SourceOp op, OpAdaptor adaptor,
628660 ConversionPatternRewriter &rewriter) const {
@@ -631,6 +663,13 @@ class OpConversionPattern : public ConversionPattern {
631663 rewrite (op, adaptor, rewriter);
632664 return success ();
633665 }
666+ virtual LogicalResult
667+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
668+ ConversionPatternRewriter &rewriter) const {
669+ SmallVector<Value> oneToOneOperands =
670+ getOneToOneAdaptorOperands (adaptor.getOperands ());
671+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
672+ }
634673
635674private:
636675 using ConversionPattern::matchAndRewrite;
@@ -656,18 +695,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656695 ConversionPatternRewriter &rewriter) const final {
657696 rewrite (cast<SourceOp>(op), operands, rewriter);
658697 }
698+ void rewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
699+ ConversionPatternRewriter &rewriter) const final {
700+ rewrite (cast<SourceOp>(op), operands, rewriter);
701+ }
659702 LogicalResult
660703 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
661704 ConversionPatternRewriter &rewriter) const final {
662705 return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
663706 }
707+ LogicalResult
708+ matchAndRewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
709+ ConversionPatternRewriter &rewriter) const final {
710+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
711+ }
664712
665713 // / Rewrite and Match methods that operate on the SourceOp type. These must be
666714 // / overridden by the derived pattern class.
667715 virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
668716 ConversionPatternRewriter &rewriter) const {
669717 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
670718 }
719+ virtual void rewrite (SourceOp op, ArrayRef<ArrayRef<Value>> operands,
720+ ConversionPatternRewriter &rewriter) const {
721+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
722+ }
671723 virtual LogicalResult
672724 matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
673725 ConversionPatternRewriter &rewriter) const {
@@ -676,6 +728,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676728 rewrite (op, operands, rewriter);
677729 return success ();
678730 }
731+ virtual LogicalResult
732+ matchAndRewrite (SourceOp op, ArrayRef<ArrayRef<Value>> operands,
733+ ConversionPatternRewriter &rewriter) const {
734+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
735+ }
679736
680737private:
681738 using ConversionPattern::matchAndRewrite;
0 commit comments