@@ -3290,18 +3290,6 @@ LogicalResult InsertOp::verify() {
32903290 return success ();
32913291}
32923292
3293- // Calculate the linearized position of the continuous chunk of elements to
3294- // insert, based on the shape of the value to insert and the positions to insert
3295- // at.
3296- static int64_t calculateInsertPosition (VectorType destTy,
3297- ArrayRef<int64_t > positions) {
3298- llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3299- assert (positions.size () <= completePositions.size () &&
3300- " positions size must be less than or equal to destTy rank" );
3301- copy (positions, completePositions.begin ());
3302- return linearize (completePositions, computeStrides (destTy.getShape ()));
3303- }
3304-
33053293namespace {
33063294
33073295// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3339,132 +3327,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33393327 return success ();
33403328 }
33413329};
3342-
3343- // / Pattern to optimize a chain of insertions.
3344- // /
3345- // / This pattern identifies chains of vector.insert operations that:
3346- // / 1. Only insert values at static positions.
3347- // / 2. Completely initialize all elements in the resulting vector.
3348- // / 3. All intermediate insert operations have only one use.
3349- // /
3350- // / When these conditions are met, the entire chain can be replaced with a
3351- // / single vector.from_elements operation.
3352- // /
3353- // / To keep this pattern simple, and avoid spending too much time on matching
3354- // / fragmented insert chains, this pattern only considers the last insert op in
3355- // / the chain.
3356- // /
3357- // / Example transformation:
3358- // / %poison = ub.poison : vector<2xi32>
3359- // / %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3360- // / %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3361- // / ->
3362- // / %result = vector.from_elements %c1, %c2 : vector<2xi32>
3363- class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3364- public:
3365- using OpRewritePattern::OpRewritePattern;
3366- LogicalResult matchAndRewrite (InsertOp op,
3367- PatternRewriter &rewriter) const override {
3368-
3369- VectorType destTy = op.getDestVectorType ();
3370- if (destTy.isScalable ())
3371- return failure ();
3372- // Ensure this is the trailing vector.insert op in a chain of inserts.
3373- for (Operation *user : op.getResult ().getUsers ())
3374- if (auto insertOp = dyn_cast<InsertOp>(user))
3375- if (insertOp.getDest () == op.getResult ())
3376- return failure ();
3377-
3378- InsertOp currentOp = op;
3379- SmallVector<InsertOp> chainInsertOps;
3380- while (currentOp) {
3381- // Check cond 1: Dynamic position is not supported.
3382- if (currentOp.hasDynamicPosition ())
3383- return failure ();
3384-
3385- chainInsertOps.push_back (currentOp);
3386- currentOp = currentOp.getDest ().getDefiningOp <InsertOp>();
3387- // Check cond 3: Intermediate inserts have only one use to avoid an
3388- // explosion of vectors.
3389- if (currentOp && !currentOp->hasOneUse ())
3390- return failure ();
3391- }
3392-
3393- int64_t vectorSize = destTy.getNumElements ();
3394- int64_t initializedCount = 0 ;
3395- SmallVector<bool > initializedDestIdxs (vectorSize, false );
3396- SmallVector<int64_t > pendingInsertPos;
3397- SmallVector<int64_t > pendingInsertSize;
3398- SmallVector<Value> pendingInsertValues;
3399-
3400- for (auto insertOp : chainInsertOps) {
3401- // This pattern can do nothing with poison index.
3402- if (is_contained (insertOp.getStaticPosition (), InsertOp::kPoisonIndex ))
3403- return failure ();
3404-
3405- // Calculate the linearized position for inserting elements.
3406- int64_t insertBeginPosition =
3407- calculateInsertPosition (destTy, insertOp.getStaticPosition ());
3408-
3409- // The valueToStore operand may be a vector or a scalar. Need to handle
3410- // both cases.
3411- int64_t insertSize = 1 ;
3412- if (auto srcVectorType =
3413- llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType ()))
3414- insertSize = srcVectorType.getNumElements ();
3415-
3416- assert (insertBeginPosition + insertSize <= vectorSize &&
3417- " insert would overflow the vector" );
3418-
3419- for (auto index : llvm::seq<int64_t >(insertBeginPosition,
3420- insertBeginPosition + insertSize)) {
3421- if (initializedDestIdxs[index])
3422- continue ;
3423- initializedDestIdxs[index] = true ;
3424- ++initializedCount;
3425- }
3426-
3427- // Defer the creation of ops before we can make sure the pattern can
3428- // succeed.
3429- pendingInsertPos.push_back (insertBeginPosition);
3430- pendingInsertSize.push_back (insertSize);
3431- pendingInsertValues.push_back (insertOp.getValueToStore ());
3432-
3433- if (initializedCount == vectorSize)
3434- break ;
3435- }
3436-
3437- // Check cond 2: all positions must be initialized.
3438- if (initializedCount != vectorSize)
3439- return failure ();
3440-
3441- SmallVector<Value> elements (vectorSize);
3442- for (auto [insertBeginPosition, insertSize, valueToStore] :
3443- llvm::reverse (llvm::zip (pendingInsertPos, pendingInsertSize,
3444- pendingInsertValues))) {
3445- auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType ());
3446-
3447- if (!srcVectorType) {
3448- elements[insertBeginPosition] = valueToStore;
3449- continue ;
3450- }
3451-
3452- SmallVector<Type> elementToInsertTypes (insertSize,
3453- srcVectorType.getElementType ());
3454- // Get all elements from the vector in row-major order.
3455- auto elementsToInsert = rewriter.create <vector::ToElementsOp>(
3456- op.getLoc (), elementToInsertTypes, valueToStore);
3457- for (int64_t linearIdx = 0 ; linearIdx < insertSize; linearIdx++) {
3458- elements[insertBeginPosition + linearIdx] =
3459- elementsToInsert.getResult (linearIdx);
3460- }
3461- }
3462-
3463- rewriter.replaceOpWithNewOp <vector::FromElementsOp>(op, destTy, elements);
3464- return success ();
3465- }
3466- };
3467-
34683330} // namespace
34693331
34703332static Attribute
@@ -3491,9 +3353,13 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
34913353 !insertOp->hasOneUse ())
34923354 return {};
34933355
3494- // Calculate the linearized position for inserting elements.
3356+ // Calculate the linearized position of the contiguous chunk of elements to
3357+ // insert.
3358+ llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3359+ copy (insertOp.getStaticPosition (), completePositions.begin ());
34953360 int64_t insertBeginPosition =
3496- calculateInsertPosition (destTy, insertOp.getStaticPosition ());
3361+ linearize (completePositions, computeStrides (destTy.getShape ()));
3362+
34973363 SmallVector<Attribute> insertedValues;
34983364 Type destEltType = destTy.getElementType ();
34993365
@@ -3529,8 +3395,7 @@ static Value foldInsertUseChain(InsertOp insertOp) {
35293395
// Adds the canonicalization patterns for InsertOp to `results`, constructing
// each one with `context`.
// This hunk shows both versions of the single registration call: the removed
// (-) lines also registered InsertChainFullyInitialized, while the added (+)
// line registers only InsertToBroadcast, BroadcastFolder and
// InsertSplatToSplat.
35303396void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
35313397 MLIRContext *context) {
3532- results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3533- InsertChainFullyInitialized>(context);
3398+ results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
35343399}
35353400
35363401OpFoldResult InsertOp::fold (FoldAdaptor adaptor) {
0 commit comments