1818using namespace mlir ;
1919using namespace mlir ::vector;
2020
21- // Helper that picks the proper sequence for inserting.
22- static Value insertOne (PatternRewriter &rewriter, Location loc, Value from,
23- Value into, int64_t offset) {
24- auto vectorType = cast<VectorType>(into.getType ());
25- if (vectorType.getRank () > 1 )
26- return rewriter.create <InsertOp>(loc, from, into, offset);
27- return rewriter.create <vector::InsertElementOp>(
28- loc, vectorType, from, into,
29- rewriter.create <arith::ConstantIndexOp>(loc, offset));
30- }
31-
32- // Helper that picks the proper sequence for extracting.
33- static Value extractOne (PatternRewriter &rewriter, Location loc, Value vector,
34- int64_t offset) {
35- auto vectorType = cast<VectorType>(vector.getType ());
36- if (vectorType.getRank () > 1 )
37- return rewriter.create <ExtractOp>(loc, vector, offset);
38- return rewriter.create <vector::ExtractElementOp>(
39- loc, vectorType.getElementType (), vector,
40- rewriter.create <arith::ConstantIndexOp>(loc, offset));
41- }
42-
4321// / RewritePattern for InsertStridedSliceOp where source and destination vectors
4422// / have different ranks.
4523// /
@@ -173,11 +151,13 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
173151 for (int64_t off = offset, e = offset + size * stride, idx = 0 ; off < e;
174152 off += stride, ++idx) {
175153 // 1. extract the proper subvector (or element) from source
176- Value extractedSource = extractOne (rewriter, loc, op.getSource (), idx);
154+ Value extractedSource =
155+ rewriter.create <ExtractOp>(loc, op.getSource (), idx);
177156 if (isa<VectorType>(extractedSource.getType ())) {
178157 // 2. If we have a vector, extract the proper subvector from destination
179158 // Otherwise we are at the element level and no need to recurse.
180- Value extractedDest = extractOne (rewriter, loc, op.getDest (), off);
159+ Value extractedDest =
160+ rewriter.create <ExtractOp>(loc, op.getDest (), off);
181161 // 3. Reduce the problem to lowering a new InsertStridedSlice op with
182162 // smaller rank.
183163 extractedSource = rewriter.create <InsertStridedSliceOp>(
@@ -186,7 +166,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
186166 getI64SubArray (op.getStrides (), /* dropFront=*/ 1 ));
187167 }
188168 // 4. Insert the extractedSource into the res vector.
189- res = insertOne ( rewriter, loc, extractedSource, res, off);
169+ res = rewriter. create <InsertOp>( loc, extractedSource, res, off);
190170 }
191171
192172 rewriter.replaceOp (op, res);
@@ -277,8 +257,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
277257};
278258
279259// / RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
280- // / For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
281- // / rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
260+ // / For such cases, we can rewrite it to ExtractOp + lower rank
261+ // / ExtractStridedSliceOp + InsertOp for the n-D case.
282262class DecomposeNDExtractStridedSlice
283263 : public OpRewritePattern<ExtractStridedSliceOp> {
284264public:
@@ -317,12 +297,12 @@ class DecomposeNDExtractStridedSlice
317297 Value res = rewriter.create <SplatOp>(loc, dstType, zero);
318298 for (int64_t off = offset, e = offset + size * stride, idx = 0 ; off < e;
319299 off += stride, ++idx) {
320- Value one = extractOne ( rewriter, loc, op.getVector (), off);
300+ Value one = rewriter. create <ExtractOp>( loc, op.getVector (), off);
321301 Value extracted = rewriter.create <ExtractStridedSliceOp>(
322302 loc, one, getI64SubArray (op.getOffsets (), /* dropFront=*/ 1 ),
323303 getI64SubArray (op.getSizes (), /* dropFront=*/ 1 ),
324304 getI64SubArray (op.getStrides (), /* dropFront=*/ 1 ));
325- res = insertOne ( rewriter, loc, extracted, res, idx);
305+ res = rewriter. create <InsertOp>( loc, extracted, res, idx);
326306 }
327307 rewriter.replaceOp (op, res);
328308 return success ();
0 commit comments