Skip to content

Commit 53f216b

Browse files
committed
Feedback
1 parent 7eca1d8 commit 53f216b

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ using namespace mlir::vector;
3131
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
3232
int step = 1) {
3333
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
34+
assert(indices[dim] < vecType.getDimSize(dim) &&
35+
"Indices are out of bound");
3436
indices[dim] += step;
3537
if (indices[dim] < vecType.getDimSize(dim))
3638
break;
@@ -68,8 +70,8 @@ class ShapeCastOpNDDownCastRewritePattern
6870
numElts *= sourceVectorType.getDimSize(dim);
6971

7072
auto loc = op.getLoc();
71-
SmallVector<int64_t> srcIdx(srcRank - 1);
72-
SmallVector<int64_t> resIdx(resRank);
73+
SmallVector<int64_t> srcIdx(srcRank - 1, 0);
74+
SmallVector<int64_t> resIdx(resRank, 0);
7375
int64_t extractSize = sourceVectorType.getShape().back();
7476
Value result = rewriter.create<arith::ConstantOp>(
7577
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
@@ -124,8 +126,8 @@ class ShapeCastOpNDUpCastRewritePattern
124126
// Compute the indices of each 1-D vector element of the source slice
125127
// extraction and destination insertion and generate such instructions.
126128
auto loc = op.getLoc();
127-
SmallVector<int64_t> srcIdx(srcRank);
128-
SmallVector<int64_t> resIdx(resRank - 1);
129+
SmallVector<int64_t> srcIdx(srcRank, 0);
130+
SmallVector<int64_t> resIdx(resRank - 1, 0);
129131
int64_t extractSize = resultVectorType.getShape().back();
130132
Value result = rewriter.create<arith::ConstantOp>(
131133
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
@@ -180,8 +182,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
180182
// x[0,1,0] = y[0,2]
181183
// etc., incrementing the two index vectors "row-major"
182184
// within the source and result shape.
183-
SmallVector<int64_t> srcIdx(srcRank);
184-
SmallVector<int64_t> resIdx(resRank);
185+
SmallVector<int64_t> srcIdx(srcRank, 0);
186+
SmallVector<int64_t> resIdx(resRank, 0);
185187
Value result = rewriter.create<arith::ConstantOp>(
186188
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
187189
for (int64_t i = 0; i < numElts; i++) {
@@ -292,8 +294,8 @@ class ScalableShapeCastOpRewritePattern
292294
Value result = rewriter.create<arith::ConstantOp>(
293295
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
294296

295-
SmallVector<int64_t> srcIdx(srcRank);
296-
SmallVector<int64_t> resIdx(resRank);
297+
SmallVector<int64_t> srcIdx(srcRank, 0);
298+
SmallVector<int64_t> resIdx(resRank, 0);
297299

298300
// TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
299301
// once D150000 lands.

0 commit comments

Comments
 (0)