@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454 return slicedIndices;
5555}
5656
57+ // compute the new indices for vector.load/store by adding offsets to
58+ // originalIndices.
59+ // It assumes m <= n (m = offsets.size(), n = originalIndices.size())
60+ // Last m of originalIndices will be updated.
61+ static SmallVector<Value> computeIndices (PatternRewriter &rewriter,
62+ Location loc,
63+ ArrayRef<Value> originalIndices,
64+ ArrayRef<int64_t > offsets) {
65+ assert (offsets.size () <= originalIndices.size () &&
66+ " Offsets should not exceed the number of original indices" );
67+ SmallVector<Value> indices (originalIndices);
68+ auto originalIter = originalIndices.rbegin ();
69+ auto offsetsIter = offsets.rbegin ();
70+ auto indicesIter = indices.rbegin ();
71+ while (offsetsIter != offsets.rend ()) {
72+ Value original = *originalIter;
73+ int64_t offset = *offsetsIter;
74+ if (offset != 0 )
75+ *indicesIter = rewriter.create <arith::AddIOp>(
76+ loc, original, rewriter.create <arith::ConstantIndexOp>(loc, offset));
77+ originalIter++;
78+ offsetsIter++;
79+ indicesIter++;
80+ }
81+ return indices;
82+ };
83+
5784// Clones `op` into a new operations that takes `operands` and returns
5885// `resultTypes`.
5986static Operation *cloneOpWithOperandsAndTypes (OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
631658 vector::UnrollVectorOptions options;
632659};
633660
661+ struct UnrollLoadPattern : public OpRewritePattern <vector::LoadOp> {
662+ UnrollLoadPattern (MLIRContext *context,
663+ const vector::UnrollVectorOptions &options,
664+ PatternBenefit benefit = 1 )
665+ : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
666+
667+ LogicalResult matchAndRewrite (vector::LoadOp loadOp,
668+ PatternRewriter &rewriter) const override {
669+ VectorType vecType = loadOp.getVectorType ();
670+ // Only unroll >1D loads
671+ if (vecType.getRank () <= 1 )
672+ return failure ();
673+
674+ Location loc = loadOp.getLoc ();
675+ ArrayRef<int64_t > originalShape = vecType.getShape ();
676+
677+ // Target type is a 1D vector of the innermost dimension.
678+ auto targetType =
679+ VectorType::get (originalShape.back (), vecType.getElementType ());
680+
681+ // Extend the targetShape to the same rank of original shape by padding 1s
682+ // for leading dimensions for convenience of computing offsets
683+ SmallVector<int64_t > targetShape (originalShape.size (), 1 );
684+ targetShape.back () = originalShape.back ();
685+
686+ Value result = rewriter.create <arith::ConstantOp>(
687+ loc, vecType, rewriter.getZeroAttr (vecType));
688+
689+ SmallVector<Value> originalIndices (loadOp.getIndices ().begin (),
690+ loadOp.getIndices ().end ());
691+
692+ for (SmallVector<int64_t > offsets :
693+ StaticTileOffsetRange (originalShape, targetShape)) {
694+ SmallVector<Value> indices =
695+ computeIndices (rewriter, loc, originalIndices, offsets);
696+ Value slice = rewriter.create <vector::LoadOp>(loc, targetType,
697+ loadOp.getBase (), indices);
698+ // Insert the slice into the result at the correct position.
699+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
700+ loc, slice, result, offsets, SmallVector<int64_t >({1 }));
701+ }
702+ rewriter.replaceOp (loadOp, result);
703+ return success ();
704+ }
705+
706+ private:
707+ vector::UnrollVectorOptions options;
708+ };
709+
710+ struct UnrollStorePattern : public OpRewritePattern <vector::StoreOp> {
711+ UnrollStorePattern (MLIRContext *context,
712+ const vector::UnrollVectorOptions &options,
713+ PatternBenefit benefit = 1 )
714+ : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
715+
716+ LogicalResult matchAndRewrite (vector::StoreOp storeOp,
717+ PatternRewriter &rewriter) const override {
718+ VectorType vecType = storeOp.getVectorType ();
719+ // Only unroll >1D stores.
720+ if (vecType.getRank () <= 1 )
721+ return failure ();
722+
723+ Location loc = storeOp.getLoc ();
724+ ArrayRef<int64_t > originalShape = vecType.getShape ();
725+
726+ // Extend the targetShape to the same rank of original shape by padding 1s
727+ // for leading dimensions for convenience of computing offsets
728+ SmallVector<int64_t > targetShape (originalShape.size (), 1 );
729+ targetShape.back () = originalShape.back ();
730+
731+ Value base = storeOp.getBase ();
732+ Value vector = storeOp.getValueToStore ();
733+
734+ SmallVector<Value> originalIndices (storeOp.getIndices ().begin (),
735+ storeOp.getIndices ().end ());
736+
737+ for (SmallVector<int64_t > offsets :
738+ StaticTileOffsetRange (originalShape, targetShape)) {
739+ SmallVector<Value> indices =
740+ computeIndices (rewriter, loc, originalIndices, offsets);
741+ offsets.pop_back ();
742+ Value slice = rewriter.create <vector::ExtractOp>(loc, vector, offsets);
743+ rewriter.create <vector::StoreOp>(loc, slice, base, indices);
744+ }
745+ rewriter.eraseOp (storeOp);
746+ return success ();
747+ }
748+
749+ private:
750+ vector::UnrollVectorOptions options;
751+ };
752+
634753} // namespace
635754
636755void mlir::vector::populateVectorUnrollPatterns (
@@ -639,6 +758,6 @@ void mlir::vector::populateVectorUnrollPatterns(
639758 patterns.add <UnrollTransferReadPattern, UnrollTransferWritePattern,
640759 UnrollContractionPattern, UnrollElementwisePattern,
641760 UnrollReductionPattern, UnrollMultiReductionPattern,
642- UnrollTransposePattern, UnrollGatherPattern>(
643- patterns.getContext (), options, benefit);
761+ UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
762+ UnrollStorePattern>( patterns.getContext (), options, benefit);
644763}
0 commit comments