@@ -31,6 +31,8 @@ using namespace mlir::vector;
3131static void incIdx (SmallVectorImpl<int64_t > &indices, VectorType vecType,
3232 int step = 1 ) {
3333 for (int dim : llvm::reverse (llvm::seq<int >(0 , indices.size ()))) {
34+ assert (indices[dim] < vecType.getDimSize (dim) &&
35+ " Indices are out of bound" );
3436 indices[dim] += step;
3537 if (indices[dim] < vecType.getDimSize (dim))
3638 break ;
@@ -68,8 +70,8 @@ class ShapeCastOpNDDownCastRewritePattern
6870 numElts *= sourceVectorType.getDimSize (dim);
6971
7072 auto loc = op.getLoc ();
71- SmallVector<int64_t > srcIdx (srcRank - 1 );
72- SmallVector<int64_t > resIdx (resRank);
73+ SmallVector<int64_t > srcIdx (srcRank - 1 , 0 );
74+ SmallVector<int64_t > resIdx (resRank, 0 );
7375 int64_t extractSize = sourceVectorType.getShape ().back ();
7476 Value result = rewriter.create <arith::ConstantOp>(
7577 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
@@ -124,8 +126,8 @@ class ShapeCastOpNDUpCastRewritePattern
124126 // Compute the indices of each 1-D vector element of the source slice
125127 // extraction and destination insertion and generate such instructions.
126128 auto loc = op.getLoc ();
127- SmallVector<int64_t > srcIdx (srcRank);
128- SmallVector<int64_t > resIdx (resRank - 1 );
129+ SmallVector<int64_t > srcIdx (srcRank, 0 );
130+ SmallVector<int64_t > resIdx (resRank - 1 , 0 );
129131 int64_t extractSize = resultVectorType.getShape ().back ();
130132 Value result = rewriter.create <arith::ConstantOp>(
131133 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
@@ -180,8 +182,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
180182 // x[0,1,0] = y[0,2]
181183 // etc., incrementing the two index vectors "row-major"
182184 // within the source and result shape.
183- SmallVector<int64_t > srcIdx (srcRank);
184- SmallVector<int64_t > resIdx (resRank);
185+ SmallVector<int64_t > srcIdx (srcRank, 0 );
186+ SmallVector<int64_t > resIdx (resRank, 0 );
185187 Value result = rewriter.create <arith::ConstantOp>(
186188 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
187189 for (int64_t i = 0 ; i < numElts; i++) {
@@ -292,8 +294,8 @@ class ScalableShapeCastOpRewritePattern
292294 Value result = rewriter.create <arith::ConstantOp>(
293295 loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
294296
295- SmallVector<int64_t > srcIdx (srcRank);
296- SmallVector<int64_t > resIdx (resRank);
297+ SmallVector<int64_t > srcIdx (srcRank, 0 );
298+ SmallVector<int64_t > resIdx (resRank, 0 );
297299
298300 // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
299301 // once D150000 lands.
0 commit comments