@@ -54,6 +54,28 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454 return slicedIndices;
5555}
5656
57+ // Compute the new indices by adding `offsets` to `originalIndices`.
58+ // If m < n (m = offsets.size(), n = originalIndices.size()),
59+ // then only the trailing m values in `originalIndices` are updated.
60+ static SmallVector<Value> sliceLoadStoreIndices (PatternRewriter &rewriter,
61+ Location loc,
62+ OperandRange originalIndices,
63+ ArrayRef<int64_t > offsets) {
64+ assert (offsets.size () <= originalIndices.size () &&
65+ " Offsets should not exceed the number of original indices" );
66+ SmallVector<Value> indices (originalIndices);
67+
68+ auto start = indices.size () - offsets.size ();
69+ for (auto [i, offset] : llvm::enumerate (offsets)) {
70+ if (offset != 0 ) {
71+ indices[start + i] = rewriter.create <arith::AddIOp>(
72+ loc, originalIndices[start + i],
73+ rewriter.create <arith::ConstantIndexOp>(loc, offset));
74+ }
75+ }
76+ return indices;
77+ }
78+
5779// Clones `op` into a new operation that takes `operands` and returns
5880// `resultTypes`.
5981static Operation *cloneOpWithOperandsAndTypes (OpBuilder &builder, Location loc,
@@ -631,6 +653,90 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
631653 vector::UnrollVectorOptions options;
632654};
633655
656+ struct UnrollLoadPattern : public OpRewritePattern <vector::LoadOp> {
657+ UnrollLoadPattern (MLIRContext *context,
658+ const vector::UnrollVectorOptions &options,
659+ PatternBenefit benefit = 1 )
660+ : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
661+
662+ LogicalResult matchAndRewrite (vector::LoadOp loadOp,
663+ PatternRewriter &rewriter) const override {
664+ VectorType vecType = loadOp.getVectorType ();
665+
666+ auto targetShape = getTargetShape (options, loadOp);
667+ if (!targetShape)
668+ return failure ();
669+
670+ Location loc = loadOp.getLoc ();
671+ ArrayRef<int64_t > originalShape = vecType.getShape ();
672+ SmallVector<int64_t > strides (targetShape->size (), 1 );
673+
674+ Value result = rewriter.create <arith::ConstantOp>(
675+ loc, vecType, rewriter.getZeroAttr (vecType));
676+
677+ SmallVector<int64_t > loopOrder =
678+ getUnrollOrder (originalShape.size (), loadOp, options);
679+
680+ auto targetVecType =
681+ VectorType::get (*targetShape, vecType.getElementType ());
682+
683+ for (SmallVector<int64_t > offsets :
684+ StaticTileOffsetRange (originalShape, *targetShape, loopOrder)) {
685+ SmallVector<Value> indices =
686+ sliceLoadStoreIndices (rewriter, loc, loadOp.getIndices (), offsets);
687+ Value slicedLoad = rewriter.create <vector::LoadOp>(
688+ loc, targetVecType, loadOp.getBase (), indices);
689+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
690+ loc, slicedLoad, result, offsets, strides);
691+ }
692+ rewriter.replaceOp (loadOp, result);
693+ return success ();
694+ }
695+
696+ private:
697+ vector::UnrollVectorOptions options;
698+ };
699+
700+ struct UnrollStorePattern : public OpRewritePattern <vector::StoreOp> {
701+ UnrollStorePattern (MLIRContext *context,
702+ const vector::UnrollVectorOptions &options,
703+ PatternBenefit benefit = 1 )
704+ : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
705+
706+ LogicalResult matchAndRewrite (vector::StoreOp storeOp,
707+ PatternRewriter &rewriter) const override {
708+ VectorType vecType = storeOp.getVectorType ();
709+
710+ auto targetShape = getTargetShape (options, storeOp);
711+ if (!targetShape)
712+ return failure ();
713+
714+ Location loc = storeOp.getLoc ();
715+ ArrayRef<int64_t > originalShape = vecType.getShape ();
716+ SmallVector<int64_t > strides (targetShape->size (), 1 );
717+
718+ Value base = storeOp.getBase ();
719+ Value vector = storeOp.getValueToStore ();
720+
721+ SmallVector<int64_t > loopOrder =
722+ getUnrollOrder (originalShape.size (), storeOp, options);
723+
724+ for (SmallVector<int64_t > offsets :
725+ StaticTileOffsetRange (originalShape, *targetShape, loopOrder)) {
726+ SmallVector<Value> indices =
727+ sliceLoadStoreIndices (rewriter, loc, storeOp.getIndices (), offsets);
728+ Value slice = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
729+ loc, vector, offsets, *targetShape, strides);
730+ rewriter.create <vector::StoreOp>(loc, slice, base, indices);
731+ }
732+ rewriter.eraseOp (storeOp);
733+ return success ();
734+ }
735+
736+ private:
737+ vector::UnrollVectorOptions options;
738+ };
739+
634740struct UnrollBroadcastPattern : public OpRewritePattern <vector::BroadcastOp> {
635741 UnrollBroadcastPattern (MLIRContext *context,
636742 const vector::UnrollVectorOptions &options,
@@ -699,10 +805,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
699805void mlir::vector::populateVectorUnrollPatterns (
700806 RewritePatternSet &patterns, const UnrollVectorOptions &options,
701807 PatternBenefit benefit) {
702- patterns
703- . add <UnrollTransferReadPattern, UnrollTransferWritePattern ,
704- UnrollContractionPattern, UnrollElementwisePattern ,
705- UnrollReductionPattern, UnrollMultiReductionPattern ,
706- UnrollTransposePattern, UnrollGatherPattern , UnrollBroadcastPattern>(
707- patterns.getContext (), options, benefit);
808+ patterns. add <UnrollTransferReadPattern, UnrollTransferWritePattern,
809+ UnrollContractionPattern, UnrollElementwisePattern ,
810+ UnrollReductionPattern, UnrollMultiReductionPattern ,
811+ UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern ,
812+ UnrollStorePattern , UnrollBroadcastPattern>(
813+ patterns.getContext (), options, benefit);
708814}
0 commit comments