@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537537 ConversionPatternRewriter &rewriter) const {
538538 llvm_unreachable (" unimplemented rewrite" );
539539 }
540+ virtual void rewrite (Operation *op, ArrayRef<ValueRange> operands,
541+ ConversionPatternRewriter &rewriter) const {
542+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
543+ }
540544
541545 // / Hook for derived classes to implement combined matching and rewriting.
542546 virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547551 rewrite (op, operands, rewriter);
548552 return success ();
549553 }
554+ virtual LogicalResult
555+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
556+ ConversionPatternRewriter &rewriter) const {
557+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
558+ }
550559
551560 // / Attempt to match and rewrite the IR root at the specified operation.
552561 LogicalResult matchAndRewrite (Operation *op,
@@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
574583 : RewritePattern(std::forward<Args>(args)...),
575584 typeConverter (&typeConverter) {}
576585
586+ // / Given an array of value ranges, which are the inputs to a 1:N adaptor,
587+ // / try to extract the single value of each range to construct a the inputs
588+ // / for a 1:1 adaptor.
589+ // /
590+ // / This function produces a fatal error if at least one range has 0 or
591+ // / more than 1 value: "pattern 'name' does not support 1:N conversion"
592+ SmallVector<Value>
593+ getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
594+
577595protected:
578596 // / An optional type converter for use by this pattern.
579597 const TypeConverter *typeConverter = nullptr ;
@@ -589,6 +607,8 @@ template <typename SourceOp>
589607class OpConversionPattern : public ConversionPattern {
590608public:
591609 using OpAdaptor = typename SourceOp::Adaptor;
610+ using OneToNOpAdaptor =
611+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
592612
593613 OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
594614 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
607627 auto sourceOp = cast<SourceOp>(op);
608628 rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
609629 }
630+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
631+ ConversionPatternRewriter &rewriter) const final {
632+ auto sourceOp = cast<SourceOp>(op);
633+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
634+ }
610635 LogicalResult
611636 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
612637 ConversionPatternRewriter &rewriter) const final {
613638 auto sourceOp = cast<SourceOp>(op);
614639 return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
615640 }
641+ LogicalResult
642+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
643+ ConversionPatternRewriter &rewriter) const final {
644+ auto sourceOp = cast<SourceOp>(op);
645+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
646+ rewriter);
647+ }
616648
617649 // / Rewrite and Match methods that operate on the SourceOp type. These must be
618650 // / overridden by the derived pattern class.
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
623655 ConversionPatternRewriter &rewriter) const {
624656 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
625657 }
658+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
659+ ConversionPatternRewriter &rewriter) const {
660+ SmallVector<Value> oneToOneOperands =
661+ getOneToOneAdaptorOperands (adaptor.getOperands ());
662+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
663+ }
626664 virtual LogicalResult
627665 matchAndRewrite (SourceOp op, OpAdaptor adaptor,
628666 ConversionPatternRewriter &rewriter) const {
@@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
631669 rewrite (op, adaptor, rewriter);
632670 return success ();
633671 }
672+ virtual LogicalResult
673+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
674+ ConversionPatternRewriter &rewriter) const {
675+ SmallVector<Value> oneToOneOperands =
676+ getOneToOneAdaptorOperands (adaptor.getOperands ());
677+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
678+ }
634679
635680private:
636681 using ConversionPattern::matchAndRewrite;
@@ -656,18 +701,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656701 ConversionPatternRewriter &rewriter) const final {
657702 rewrite (cast<SourceOp>(op), operands, rewriter);
658703 }
704+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
705+ ConversionPatternRewriter &rewriter) const final {
706+ rewrite (cast<SourceOp>(op), operands, rewriter);
707+ }
659708 LogicalResult
660709 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
661710 ConversionPatternRewriter &rewriter) const final {
662711 return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
663712 }
713+ LogicalResult
714+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
715+ ConversionPatternRewriter &rewriter) const final {
716+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
717+ }
664718
665719 // / Rewrite and Match methods that operate on the SourceOp type. These must be
666720 // / overridden by the derived pattern class.
667721 virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
668722 ConversionPatternRewriter &rewriter) const {
669723 llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
670724 }
725+ virtual void rewrite (SourceOp op, ArrayRef<ValueRange> operands,
726+ ConversionPatternRewriter &rewriter) const {
727+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
728+ }
671729 virtual LogicalResult
672730 matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
673731 ConversionPatternRewriter &rewriter) const {
@@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676734 rewrite (op, operands, rewriter);
677735 return success ();
678736 }
737+ virtual LogicalResult
738+ matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
739+ ConversionPatternRewriter &rewriter) const {
740+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
741+ }
679742
680743private:
681744 using ConversionPattern::matchAndRewrite;
0 commit comments