@@ -28,17 +28,20 @@ using namespace mlir;
2828using namespace mlir ::vector;
2929
3030// / Increments n-D `indices` by `step` starting from the innermost dimension.
31- static void incIdx (SmallVectorImpl <int64_t > & indices, VectorType vecType ,
31+ static void incIdx (MutableArrayRef <int64_t > indices, ArrayRef< int64_t > shape ,
3232 int step = 1 ) {
3333 for (int dim : llvm::reverse (llvm::seq<int >(0 , indices.size ()))) {
34- assert (indices[dim] < vecType.getDimSize (dim) &&
35- " Indices are out of bound" );
34+ int64_t dimSize = shape[dim];
35+ assert (indices[dim] < dimSize && " Indices are out of bound" );
36+
3637 indices[dim] += step;
37- if (indices[dim] < vecType.getDimSize (dim))
38+
39+ int64_t spill = indices[dim] / dimSize;
40+ if (spill == 0 )
3841 break ;
3942
40- indices[dim] = 0 ;
41- step = 1 ;
43+ indices[dim] %= dimSize ;
44+ step = spill ;
4245 }
4346}
4447
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
7982 // and destination slice insertion and generate such instructions.
8083 for (int64_t i = 0 ; i < numElts; ++i) {
8184 if (i != 0 ) {
82- incIdx (srcIdx, sourceVectorType, /* step=*/ 1 );
83- incIdx (resIdx, resultVectorType, /* step=*/ extractSize);
85+ incIdx (srcIdx, sourceVectorType. getShape () , /* step=*/ 1 );
86+ incIdx (resIdx, resultVectorType. getShape () , /* step=*/ extractSize);
8487 }
8588
8689 Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
131134 Value result = rewriter.create <ub::PoisonOp>(loc, resultVectorType);
132135 for (int64_t i = 0 ; i < numElts; ++i) {
133136 if (i != 0 ) {
134- incIdx (srcIdx, sourceVectorType, /* step=*/ extractSize);
135- incIdx (resIdx, resultVectorType, /* step=*/ 1 );
137+ incIdx (srcIdx, sourceVectorType. getShape () , /* step=*/ extractSize);
138+ incIdx (resIdx, resultVectorType. getShape () , /* step=*/ 1 );
136139 }
137140
138141 Value extract = rewriter.create <vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
157160 LogicalResult matchAndRewrite (vector::ShapeCastOp op,
158161 PatternRewriter &rewriter) const override {
159162 Location loc = op.getLoc ();
160- auto sourceVectorType = op.getSourceVectorType ();
161- auto resultVectorType = op.getResultVectorType ();
163+ VectorType sourceType = op.getSourceVectorType ();
164+ VectorType resultType = op.getResultVectorType ();
162165
163- if (sourceVectorType .isScalable () || resultVectorType .isScalable ())
166+ if (sourceType .isScalable () || resultType .isScalable ())
164167 return failure ();
165168
166- // Special case for n-D / 1-D lowerings with better implementations.
167- int64_t srcRank = sourceVectorType.getRank ();
168- int64_t resRank = resultVectorType.getRank ();
169- if ((srcRank > 1 && resRank == 1 ) || (srcRank == 1 && resRank > 1 ))
169+ // Special case for n-D / 1-D lowerings with implementations that use
170+ // extract_strided_slice / insert_strided_slice.
171+ int64_t sourceRank = sourceType.getRank ();
172+ int64_t resultRank = resultType.getRank ();
173+ if ((sourceRank > 1 && resultRank == 1 ) ||
174+ (sourceRank == 1 && resultRank > 1 ))
170175 return failure ();
171176
172- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
173- // extract/insert chains.
174- int64_t numElts = 1 ;
175- for (int64_t r = 0 ; r < srcRank; r++)
176- numElts *= sourceVectorType.getDimSize (r);
177+ int64_t numExtracts = sourceType.getNumElements ();
178+ int64_t nbCommonInnerDims = 0 ;
179+ while (true ) {
180+ int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
181+ int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
182+ if (sourceDim < 0 || resultDim < 0 )
183+ break ;
184+ int64_t dimSize = sourceType.getDimSize (sourceDim);
185+ if (dimSize != resultType.getDimSize (resultDim))
186+ break ;
187+ numExtracts /= dimSize;
188+ ++nbCommonInnerDims;
189+ }
190+
177191 // Replace with data movement operations:
178192 // x[0,0,0] = y[0,0]
179193 // x[0,0,1] = y[0,1]
180194 // x[0,1,0] = y[0,2]
181195 // etc., incrementing the two index vectors "row-major"
182196 // within the source and result shape.
183- SmallVector<int64_t > srcIdx (srcRank, 0 );
184- SmallVector<int64_t > resIdx (resRank, 0 );
185- Value result = rewriter.create <ub::PoisonOp>(loc, resultVectorType);
186- for (int64_t i = 0 ; i < numElts; i++) {
197+ SmallVector<int64_t > sourceIndex (sourceRank - nbCommonInnerDims, 0 );
198+ SmallVector<int64_t > resultIndex (resultRank - nbCommonInnerDims, 0 );
199+ Value result = rewriter.create <ub::PoisonOp>(loc, resultType);
200+
201+ for (int64_t i = 0 ; i < numExtracts; i++) {
187202 if (i != 0 ) {
188- incIdx (srcIdx, sourceVectorType );
189- incIdx (resIdx, resultVectorType );
203+ incIdx (sourceIndex, sourceType. getShape (). drop_back (nbCommonInnerDims) );
204+ incIdx (resultIndex, resultType. getShape (). drop_back (nbCommonInnerDims) );
190205 }
191206
192207 Value extract =
193- rewriter.create <vector::ExtractOp>(loc, op.getSource (), srcIdx);
194- result = rewriter.create <vector::InsertOp>(loc, extract, result, resIdx);
208+ rewriter.create <vector::ExtractOp>(loc, op.getSource (), sourceIndex);
209+ result =
210+ rewriter.create <vector::InsertOp>(loc, extract, result, resultIndex);
195211 }
196212 rewriter.replaceOp (op, result);
197213 return success ();
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
329345
330346 // 4. Increment the insert/extract indices, stepping by minExtractionSize
331347 // for the trailing dimensions.
332- incIdx (srcIdx, sourceVectorType, /* step=*/ minExtractionSize);
333- incIdx (resIdx, resultVectorType, /* step=*/ minExtractionSize);
348+ incIdx (srcIdx, sourceVectorType. getShape () , /* step=*/ minExtractionSize);
349+ incIdx (resIdx, resultVectorType. getShape () , /* step=*/ minExtractionSize);
334350 }
335351
336352 rewriter.replaceOp (op, result);
0 commit comments