@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
172172 readOp.getPermutationMapAttr (), readOp.getPadding (), readOp.getMask (),
173173 readOp.getInBoundsAttr ());
174174
175- result = rewriter.create <vector::InsertStridedSliceOp>(
175+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
176176 loc, slicedRead, result, elementOffsets, strides);
177177 }
178178 rewriter.replaceOp (readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
213213 Value resultTensor;
214214 for (SmallVector<int64_t > elementOffsets :
215215 StaticTileOffsetRange (originalSize, *targetShape, loopOrder)) {
216- Value slicedVector = rewriter.create <vector::ExtractStridedSliceOp>(
216+ Value slicedVector = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
217217 loc, writeOp.getVector (), elementOffsets, *targetShape, strides);
218218 SmallVector<Value> indices =
219219 sliceTransferIndices (elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
289289 SmallVector<int64_t > operandShape = applyPermutationMap (
290290 permutationMap, ArrayRef<int64_t >(*targetShape));
291291 SmallVector<int64_t > operandStrides (operandOffets.size (), 1 );
292- slicesOperands[index] = rewriter.create <vector::ExtractStridedSliceOp>(
293- loc, operand, operandOffets, operandShape, operandStrides);
292+ slicesOperands[index] =
293+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
294+ loc, operand, operandOffets, operandShape, operandStrides);
294295 };
295296
296297 // Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
333334 loc, dstVecType, rewriter.getZeroAttr (dstVecType));
334335 for (const auto &it : accCache) {
335336 SmallVector<int64_t > dstStrides (it.first .size (), 1 );
336- result = rewriter.create <vector::InsertStridedSliceOp>(
337+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
337338 loc, it.second , result, it.first , dstStrides);
338339 }
339340 rewriter.replaceOp (contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
371372 StaticTileOffsetRange (originalSize, *targetShape)) {
372373 SmallVector<Value> operands;
373374 SmallVector<int64_t > operandStrides (offsets.size (), 1 );
374- Value slicedOperand = rewriter.create <vector::ExtractStridedSliceOp>(
375- loc, reductionOp.getSource (), offsets, *targetShape, operandStrides);
375+ Value slicedOperand =
376+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
377+ loc, reductionOp.getSource (), offsets, *targetShape,
378+ operandStrides);
376379 operands.push_back (slicedOperand);
377380 SmallVector<int64_t > dstShape;
378381 SmallVector<int64_t > destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
390393 if (accIt != accCache.end ())
391394 acc = accIt->second ;
392395 else
393- acc = rewriter.create <vector::ExtractStridedSliceOp>(
396+ acc = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
394397 loc, reductionOp.getAcc (), destOffset, dstShape, accStrides);
395398 operands.push_back (acc);
396399 auto targetType = VectorType::get (
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
406409 rewriter.getZeroAttr (reductionOp.getDestType ()));
407410 for (const auto &it : accCache) {
408411 SmallVector<int64_t > dstStrides (it.first .size (), 1 );
409- result = rewriter.create <vector::InsertStridedSliceOp>(
412+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
410413 loc, it.second , result, it.first , dstStrides);
411414 }
412415 rewriter.replaceOp (reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
453456 continue ;
454457 }
455458 extractOperands.push_back (
456- rewriter.create <vector::ExtractStridedSliceOp>(
459+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
457460 loc, operand.get (), offsets, *targetShape, strides));
458461 }
459462 Operation *newOp = cloneOpWithOperandsAndTypes (
460463 rewriter, loc, op, extractOperands, newVecType);
461- result = rewriter.create <vector::InsertStridedSliceOp>(
464+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
462465 loc, newOp->getResult (0 ), result, offsets, strides);
463466 }
464467 rewriter.replaceOp (op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
490493 for (SmallVector<int64_t > offsets :
491494 StaticTileOffsetRange (originalSize, *targetShape)) {
492495 SmallVector<int64_t > strides (offsets.size (), 1 );
493- Value slicedOperand = rewriter.create <vector::ExtractStridedSliceOp>(
494- loc, reductionOp.getVector (), offsets, *targetShape, strides);
496+ Value slicedOperand =
497+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
498+ loc, reductionOp.getVector (), offsets, *targetShape, strides);
495499 Operation *newOp = cloneOpWithOperandsAndTypes (
496500 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType ());
497501 Value result = newOp->getResult (0 );
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
548552 permutedOffsets[indices.value ()] = elementOffsets[indices.index ()];
549553 permutedShape[indices.value ()] = (*targetShape)[indices.index ()];
550554 }
551- Value slicedOperand = rewriter.create <vector::ExtractStridedSliceOp>(
552- loc, transposeOp.getVector (), permutedOffsets, permutedShape,
553- strides);
554- Value transposedSlice =
555- rewriter.create <vector::TransposeOp>(loc, slicedOperand, permutation);
556- result = rewriter.create <vector::InsertStridedSliceOp>(
555+ Value slicedOperand =
556+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
557+ loc, transposeOp.getVector (), permutedOffsets, permutedShape,
558+ strides);
559+ Value transposedSlice = rewriter.createOrFold <vector::TransposeOp>(
560+ loc, slicedOperand, permutation);
561+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
557562 loc, transposedSlice, result, elementOffsets, strides);
558563 }
559564 rewriter.replaceOp (transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
596601 // To get the unrolled gather, extract the same slice based on the
597602 // decomposed shape from each of the index, mask, and pass-through
598603 // vectors.
599- Value indexSubVec = rewriter.create <vector::ExtractStridedSliceOp>(
604+ Value indexSubVec = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
600605 loc, gatherOp.getIndexVec (), elementOffsets, *targetShape, strides);
601- Value maskSubVec = rewriter.create <vector::ExtractStridedSliceOp>(
606+ Value maskSubVec = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
602607 loc, gatherOp.getMask (), elementOffsets, *targetShape, strides);
603- Value passThruSubVec = rewriter.create <vector::ExtractStridedSliceOp>(
604- loc, gatherOp.getPassThru (), elementOffsets, *targetShape, strides);
608+ Value passThruSubVec =
609+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
610+ loc, gatherOp.getPassThru (), elementOffsets, *targetShape,
611+ strides);
605612 auto slicedGather = rewriter.create <vector::GatherOp>(
606613 loc, targetType, gatherOp.getBase (), gatherOp.getIndices (),
607614 indexSubVec, maskSubVec, passThruSubVec);
608615
609- result = rewriter.create <vector::InsertStridedSliceOp>(
616+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
610617 loc, slicedGather, result, elementOffsets, strides);
611618 }
612619 rewriter.replaceOp (gatherOp, result);
0 commit comments