@@ -775,23 +775,26 @@ class FlattenContiguousRowMajorTransferWritePattern
775775 unsigned targetVectorBitwidth;
776776};
777777
778- // / Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
779- // / to `memref.load` patterns. The `match` method is shared for both
780- // / `vector.extract` and `vector.extract_element`.
781- template <class VectorExtractOp >
782- class RewriteScalarExtractOfTransferReadBase
783- : public OpRewritePattern<VectorExtractOp> {
784- using Base = OpRewritePattern<VectorExtractOp>;
785-
778+ /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
779+ ///
780+ /// All the users of the transfer op must be `vector.extract` ops. If
781+ /// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
782+ /// users. Otherwise, rewrite only if the extract op is the single user of the
783+ /// transfer op. Rewriting a single vector load with multiple scalar loads may
784+ /// negatively affect performance.
785+ class RewriteScalarExtractOfTransferRead
786+ : public OpRewritePattern<vector::ExtractOp> {
786787public:
787- RewriteScalarExtractOfTransferReadBase (MLIRContext *context,
788- PatternBenefit benefit,
789- bool allowMultipleUses)
790- : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
791-
792- LogicalResult match (VectorExtractOp extractOp) const {
793- auto xferOp =
794- extractOp.getVector ().template getDefiningOp <vector::TransferReadOp>();
788+ RewriteScalarExtractOfTransferRead (MLIRContext *context,
789+ PatternBenefit benefit,
790+ bool allowMultipleUses)
791+ : OpRewritePattern(context, benefit),
792+ allowMultipleUses (allowMultipleUses) {}
793+
794+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
795+ PatternRewriter &rewriter) const override {
796+ // Match phase.
797+ auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
795798 if (!xferOp)
796799 return failure ();
797800 // Check that we are extracting a scalar and not a sub-vector.
@@ -803,8 +806,7 @@ class RewriteScalarExtractOfTransferReadBase
803806 // If multiple uses are allowed, check if all the xfer uses are extract ops.
804807 if (allowMultipleUses &&
805808 !llvm::all_of (xferOp->getUses (), [](OpOperand &use) {
806- return isa<vector::ExtractOp, vector::ExtractElementOp>(
807- use.getOwner ());
809+ return isa<vector::ExtractOp>(use.getOwner ());
808810 }))
809811 return failure ();
810812 // Mask not supported.
@@ -816,81 +818,8 @@ class RewriteScalarExtractOfTransferReadBase
816818 // Cannot rewrite if the indices may be out of bounds.
817819 if (xferOp.hasOutOfBoundsDim ())
818820 return failure ();
819- return success ();
820- }
821-
822- private:
823- bool allowMultipleUses;
824- };
825-
826- // / Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
827- // /
828- // / All the users of the transfer op must be either `vector.extractelement` or
829- // / `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
830- // / transfer ops with any number of users. Otherwise, rewrite only if the
831- // / extract op is the single user of the transfer op. Rewriting a single
832- // / vector load with multiple scalar loads may negatively affect performance.
833- class RewriteScalarExtractElementOfTransferRead
834- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
835- using RewriteScalarExtractOfTransferReadBase::
836- RewriteScalarExtractOfTransferReadBase;
837-
838- LogicalResult matchAndRewrite (vector::ExtractElementOp extractOp,
839- PatternRewriter &rewriter) const override {
840- if (failed (match (extractOp)))
841- return failure ();
842-
843- // Construct scalar load.
844- auto loc = extractOp.getLoc ();
845- auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
846- SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
847- xferOp.getIndices ().end ());
848- if (extractOp.getPosition ()) {
849- AffineExpr sym0, sym1;
850- bindSymbols (extractOp.getContext (), sym0, sym1);
851- OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
852- rewriter, loc, sym0 + sym1,
853- {newIndices[newIndices.size () - 1 ], extractOp.getPosition ()});
854- if (auto value = dyn_cast<Value>(ofr)) {
855- newIndices[newIndices.size () - 1 ] = value;
856- } else {
857- newIndices[newIndices.size () - 1 ] =
858- rewriter.create <arith::ConstantIndexOp>(loc,
859- *getConstantIntValue (ofr));
860- }
861- }
862- if (isa<MemRefType>(xferOp.getBase ().getType ())) {
863- rewriter.replaceOpWithNewOp <memref::LoadOp>(extractOp, xferOp.getBase (),
864- newIndices);
865- } else {
866- rewriter.replaceOpWithNewOp <tensor::ExtractOp>(
867- extractOp, xferOp.getBase (), newIndices);
868- }
869-
870- return success ();
871- }
872- };
873-
874- // / Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
875- // / Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
876- // /
877- // / All the users of the transfer op must be either `vector.extractelement` or
878- // / `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
879- // / transfer ops with any number of users. Otherwise, rewrite only if the
880- // / extract op is the single user of the transfer op. Rewriting a single
881- // / vector load with multiple scalar loads may negatively affect performance.
882- class RewriteScalarExtractOfTransferRead
883- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
884- using RewriteScalarExtractOfTransferReadBase::
885- RewriteScalarExtractOfTransferReadBase;
886-
887- LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
888- PatternRewriter &rewriter) const override {
889- if (failed (match (extractOp)))
890- return failure ();
891821
892- // Construct scalar load.
893- auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
822+ // Rewrite phase: construct scalar load.
894823 SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
895824 xferOp.getIndices ().end ());
896825 for (auto [i, pos] : llvm::enumerate (extractOp.getMixedPosition ())) {
@@ -931,6 +860,9 @@ class RewriteScalarExtractOfTransferRead
931860
932861 return success ();
933862 }
863+
864+ private:
865+ bool allowMultipleUses;
934866};
935867
936868/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -987,8 +919,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
987919void mlir::vector::populateScalarVectorTransferLoweringPatterns (
988920 RewritePatternSet &patterns, PatternBenefit benefit,
989921 bool allowMultipleUses) {
990- patterns.add <RewriteScalarExtractElementOfTransferRead,
991- RewriteScalarExtractOfTransferRead>(patterns.getContext (),
922+ patterns.add <RewriteScalarExtractOfTransferRead>(patterns.getContext (),
992923 benefit, allowMultipleUses);
993924 patterns.add <RewriteScalarWrite>(patterns.getContext (), benefit);
994925}
0 commit comments