Skip to content

Commit 2916932

Browse files
committed
Address comments
1 parent 2731a18 commit 2916932

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5757
// Compute the new indices by adding `offsets` to `originalIndices`.
5858
// If m < n (m = offsets.size(), n = originalIndices.size()),
5959
// then only the trailing m values in `originalIndices` are updated.
60-
static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
61-
Location loc,
62-
ArrayRef<Value> originalIndices,
63-
ArrayRef<int64_t> offsets) {
60+
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
61+
Location loc,
62+
OperandRange originalIndices,
63+
ArrayRef<int64_t> offsets) {
6464
assert(offsets.size() <= originalIndices.size() &&
6565
"Offsets should not exceed the number of original indices");
6666
SmallVector<Value> indices(originalIndices);
@@ -662,8 +662,6 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
662662
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
663663
PatternRewriter &rewriter) const override {
664664
VectorType vecType = loadOp.getVectorType();
665-
if (vecType.getRank() <= 1)
666-
return failure();
667665

668666
auto targetShape = getTargetShape(options, loadOp);
669667
if (!targetShape)
@@ -676,9 +674,6 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
676674
Value result = rewriter.create<arith::ConstantOp>(
677675
loc, vecType, rewriter.getZeroAttr(vecType));
678676

679-
SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
680-
loadOp.getIndices().end());
681-
682677
SmallVector<int64_t> loopOrder =
683678
getUnrollOrder(originalShape.size(), loadOp, options);
684679

@@ -688,11 +683,11 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
688683
for (SmallVector<int64_t> offsets :
689684
StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
690685
SmallVector<Value> indices =
691-
computeIndices(rewriter, loc, originalIndices, offsets);
692-
Value slice = rewriter.create<vector::LoadOp>(loc, targetVecType,
693-
loadOp.getBase(), indices);
686+
sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
687+
Value slicedLoad = rewriter.create<vector::LoadOp>(
688+
loc, targetVecType, loadOp.getBase(), indices);
694689
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
695-
loc, slice, result, offsets, strides);
690+
loc, slicedLoad, result, offsets, strides);
696691
}
697692
rewriter.replaceOp(loadOp, result);
698693
return success();
@@ -711,8 +706,6 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
711706
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
712707
PatternRewriter &rewriter) const override {
713708
VectorType vecType = storeOp.getVectorType();
714-
if (vecType.getRank() <= 1)
715-
return failure();
716709

717710
auto targetShape = getTargetShape(options, storeOp);
718711
if (!targetShape)
@@ -725,16 +718,13 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
725718
Value base = storeOp.getBase();
726719
Value vector = storeOp.getValueToStore();
727720

728-
SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
729-
storeOp.getIndices().end());
730-
731721
SmallVector<int64_t> loopOrder =
732722
getUnrollOrder(originalShape.size(), storeOp, options);
733723

734724
for (SmallVector<int64_t> offsets :
735725
StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
736726
SmallVector<Value> indices =
737-
computeIndices(rewriter, loc, originalIndices, offsets);
727+
sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
738728
Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
739729
loc, vector, offsets, *targetShape, strides);
740730
rewriter.create<vector::StoreOp>(loc, slice, base, indices);

0 commit comments

Comments
 (0)