@@ -57,10 +57,10 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5757// Compute the new indices by adding `offsets` to `originalIndices`.
5858// If m < n (m = offsets.size(), n = originalIndices.size()),
5959// then only the trailing m values in `originalIndices` are updated.
60- static SmallVector<Value> computeIndices (PatternRewriter &rewriter,
61- Location loc,
62- ArrayRef<Value> originalIndices,
63- ArrayRef<int64_t > offsets) {
60+ static SmallVector<Value> sliceLoadStoreIndices (PatternRewriter &rewriter,
61+ Location loc,
62+ OperandRange originalIndices,
63+ ArrayRef<int64_t > offsets) {
6464 assert (offsets.size () <= originalIndices.size () &&
6565 " Offsets should not exceed the number of original indices" );
6666 SmallVector<Value> indices (originalIndices);
@@ -662,8 +662,6 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
662662 LogicalResult matchAndRewrite (vector::LoadOp loadOp,
663663 PatternRewriter &rewriter) const override {
664664 VectorType vecType = loadOp.getVectorType ();
665- if (vecType.getRank () <= 1 )
666- return failure ();
667665
668666 auto targetShape = getTargetShape (options, loadOp);
669667 if (!targetShape)
@@ -676,9 +674,6 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
676674 Value result = rewriter.create <arith::ConstantOp>(
677675 loc, vecType, rewriter.getZeroAttr (vecType));
678676
679- SmallVector<Value> originalIndices (loadOp.getIndices ().begin (),
680- loadOp.getIndices ().end ());
681-
682677 SmallVector<int64_t > loopOrder =
683678 getUnrollOrder (originalShape.size (), loadOp, options);
684679
@@ -688,11 +683,11 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
688683 for (SmallVector<int64_t > offsets :
689684 StaticTileOffsetRange (originalShape, *targetShape, loopOrder)) {
690685 SmallVector<Value> indices =
691- computeIndices (rewriter, loc, originalIndices , offsets);
692- Value slice = rewriter.create <vector::LoadOp>(loc, targetVecType,
693- loadOp.getBase (), indices);
686+ sliceLoadStoreIndices (rewriter, loc, loadOp. getIndices () , offsets);
687+ Value slicedLoad = rewriter.create <vector::LoadOp>(
688+ loc, targetVecType, loadOp.getBase (), indices);
694689 result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
695- loc, slice , result, offsets, strides);
690+ loc, slicedLoad , result, offsets, strides);
696691 }
697692 rewriter.replaceOp (loadOp, result);
698693 return success ();
@@ -711,8 +706,6 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
711706 LogicalResult matchAndRewrite (vector::StoreOp storeOp,
712707 PatternRewriter &rewriter) const override {
713708 VectorType vecType = storeOp.getVectorType ();
714- if (vecType.getRank () <= 1 )
715- return failure ();
716709
717710 auto targetShape = getTargetShape (options, storeOp);
718711 if (!targetShape)
@@ -725,16 +718,13 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
725718 Value base = storeOp.getBase ();
726719 Value vector = storeOp.getValueToStore ();
727720
728- SmallVector<Value> originalIndices (storeOp.getIndices ().begin (),
729- storeOp.getIndices ().end ());
730-
731721 SmallVector<int64_t > loopOrder =
732722 getUnrollOrder (originalShape.size (), storeOp, options);
733723
734724 for (SmallVector<int64_t > offsets :
735725 StaticTileOffsetRange (originalShape, *targetShape, loopOrder)) {
736726 SmallVector<Value> indices =
737- computeIndices (rewriter, loc, originalIndices , offsets);
727+ sliceLoadStoreIndices (rewriter, loc, storeOp. getIndices () , offsets);
738728 Value slice = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
739729 loc, vector, offsets, *targetShape, strides);
740730 rewriter.create <vector::StoreOp>(loc, slice, base, indices);
0 commit comments