@@ -3286,18 +3286,6 @@ LogicalResult InsertOp::verify() {
32863286 return success ();
32873287}
32883288
3289- // Calculate the linearized position of the continuous chunk of elements to
3290- // insert, based on the shape of the value to insert and the positions to insert
3291- // at.
3292- static int64_t calculateInsertPosition (VectorType destTy,
3293- ArrayRef<int64_t > positions) {
3294- llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3295- assert (positions.size () <= completePositions.size () &&
3296- " positions size must be less than or equal to destTy rank" );
3297- copy (positions, completePositions.begin ());
3298- return linearize (completePositions, computeStrides (destTy.getShape ()));
3299- }
3300-
33013289namespace {
33023290
33033291// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3335,132 +3323,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33353323 return success ();
33363324 }
33373325};
3338-
3339- // / Pattern to optimize a chain of insertions.
3340- // /
3341- // / This pattern identifies chains of vector.insert operations that:
3342- // / 1. Only insert values at static positions.
3343- // / 2. Completely initialize all elements in the resulting vector.
3344- // / 3. All intermediate insert operations have only one use.
3345- // /
3346- // / When these conditions are met, the entire chain can be replaced with a
3347- // / single vector.from_elements operation.
3348- // /
3349- // / To keep this pattern simple, and avoid spending too much time on matching
3350- // / fragmented insert chains, this pattern only considers the last insert op in
3351- // / the chain.
3352- // /
3353- // / Example transformation:
3354- // / %poison = ub.poison : vector<2xi32>
3355- // / %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3356- // / %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3357- // / ->
3358- // / %result = vector.from_elements %c1, %c2 : vector<2xi32>
3359- class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3360- public:
3361- using OpRewritePattern::OpRewritePattern;
3362- LogicalResult matchAndRewrite (InsertOp op,
3363- PatternRewriter &rewriter) const override {
3364-
3365- VectorType destTy = op.getDestVectorType ();
3366- if (destTy.isScalable ())
3367- return failure ();
3368- // Ensure this is the trailing vector.insert op in a chain of inserts.
3369- for (Operation *user : op.getResult ().getUsers ())
3370- if (auto insertOp = dyn_cast<InsertOp>(user))
3371- if (insertOp.getDest () == op.getResult ())
3372- return failure ();
3373-
3374- InsertOp currentOp = op;
3375- SmallVector<InsertOp> chainInsertOps;
3376- while (currentOp) {
3377- // Check cond 1: Dynamic position is not supported.
3378- if (currentOp.hasDynamicPosition ())
3379- return failure ();
3380-
3381- chainInsertOps.push_back (currentOp);
3382- currentOp = currentOp.getDest ().getDefiningOp <InsertOp>();
3383- // Check cond 3: Intermediate inserts have only one use to avoid an
3384- // explosion of vectors.
3385- if (currentOp && !currentOp->hasOneUse ())
3386- return failure ();
3387- }
3388-
3389- int64_t vectorSize = destTy.getNumElements ();
3390- int64_t initializedCount = 0 ;
3391- SmallVector<bool > initializedDestIdxs (vectorSize, false );
3392- SmallVector<int64_t > pendingInsertPos;
3393- SmallVector<int64_t > pendingInsertSize;
3394- SmallVector<Value> pendingInsertValues;
3395-
3396- for (auto insertOp : chainInsertOps) {
3397- // This pattern can do nothing with poison index.
3398- if (is_contained (insertOp.getStaticPosition (), InsertOp::kPoisonIndex ))
3399- return failure ();
3400-
3401- // Calculate the linearized position for inserting elements.
3402- int64_t insertBeginPosition =
3403- calculateInsertPosition (destTy, insertOp.getStaticPosition ());
3404-
3405- // The valueToStore operand may be a vector or a scalar. Need to handle
3406- // both cases.
3407- int64_t insertSize = 1 ;
3408- if (auto srcVectorType =
3409- llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType ()))
3410- insertSize = srcVectorType.getNumElements ();
3411-
3412- assert (insertBeginPosition + insertSize <= vectorSize &&
3413- " insert would overflow the vector" );
3414-
3415- for (auto index : llvm::seq<int64_t >(insertBeginPosition,
3416- insertBeginPosition + insertSize)) {
3417- if (initializedDestIdxs[index])
3418- continue ;
3419- initializedDestIdxs[index] = true ;
3420- ++initializedCount;
3421- }
3422-
3423- // Defer the creation of ops before we can make sure the pattern can
3424- // succeed.
3425- pendingInsertPos.push_back (insertBeginPosition);
3426- pendingInsertSize.push_back (insertSize);
3427- pendingInsertValues.push_back (insertOp.getValueToStore ());
3428-
3429- if (initializedCount == vectorSize)
3430- break ;
3431- }
3432-
3433- // Check cond 2: all positions must be initialized.
3434- if (initializedCount != vectorSize)
3435- return failure ();
3436-
3437- SmallVector<Value> elements (vectorSize);
3438- for (auto [insertBeginPosition, insertSize, valueToStore] :
3439- llvm::reverse (llvm::zip (pendingInsertPos, pendingInsertSize,
3440- pendingInsertValues))) {
3441- auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType ());
3442-
3443- if (!srcVectorType) {
3444- elements[insertBeginPosition] = valueToStore;
3445- continue ;
3446- }
3447-
3448- SmallVector<Type> elementToInsertTypes (insertSize,
3449- srcVectorType.getElementType ());
3450- // Get all elements from the vector in row-major order.
3451- auto elementsToInsert = rewriter.create <vector::ToElementsOp>(
3452- op.getLoc (), elementToInsertTypes, valueToStore);
3453- for (int64_t linearIdx = 0 ; linearIdx < insertSize; linearIdx++) {
3454- elements[insertBeginPosition + linearIdx] =
3455- elementsToInsert.getResult (linearIdx);
3456- }
3457- }
3458-
3459- rewriter.replaceOpWithNewOp <vector::FromElementsOp>(op, destTy, elements);
3460- return success ();
3461- }
3462- };
3463-
34643326} // namespace
34653327
34663328static Attribute
@@ -3487,9 +3349,13 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
34873349 !insertOp->hasOneUse ())
34883350 return {};
34893351
3490- // Calculate the linearized position for inserting elements.
3352+ // Calculate the linearized position of the continuous chunk of elements to
3353+ // insert.
3354+ llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3355+ copy (insertOp.getStaticPosition (), completePositions.begin ());
34913356 int64_t insertBeginPosition =
3492- calculateInsertPosition (destTy, insertOp.getStaticPosition ());
3357+ linearize (completePositions, computeStrides (destTy.getShape ()));
3358+
34933359 SmallVector<Attribute> insertedValues;
34943360 Type destEltType = destTy.getElementType ();
34953361
@@ -3525,8 +3391,7 @@ static Value foldInsertUseChain(InsertOp insertOp) {
35253391
35263392void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
35273393 MLIRContext *context) {
3528- results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3529- InsertChainFullyInitialized>(context);
3394+ results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
35303395}
35313396
35323397OpFoldResult InsertOp::fold (FoldAdaptor adaptor) {
0 commit comments