Skip to content

Commit b4c31dc

Browse files
yangtetrisYang Baibanach-space
authored
[mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements (llvm#142944)
## Description This change introduces a new canonicalization pattern for the MLIR Vector dialect that optimizes chains of insertions. The optimization identifies when a vector is **completely** initialized through a series of vector.insert operations and replaces the entire chain with a single `vector.from_elements` operation. Please be aware that the new pattern **doesn't** work for poison vectors where only **some** elements are set, as MLIR doesn't support partial poison vectors for now. **New Pattern: InsertChainFullyInitialized** * Detects chains of vector.insert operations. * Validates that all insertions are at static positions, and all intermediate insertions have only one use. * Ensures the entire vector is **completely** initialized. * Replaces the entire chain with a single vector.from_elements operation. **Refactored Helper Function** * Extracted `calculateInsertPosition` from `foldDenseElementsAttrDestInsertOp` to avoid code duplication. ## Example ``` // Before: %v1 = vector.insert %c10, %v0[0] : i64 into vector<2xi64> %v2 = vector.insert %c20, %v1[1] : i64 into vector<2xi64> // After: %v2 = vector.from_elements %c10, %c20 : vector<2xi64> ``` It also works for multidimensional vectors. ``` // Before: %v1 = vector.insert %cv0, %v0[0] : vector<3xi64> into vector<2x3xi64> %v2 = vector.insert %cv1, %v1[1] : vector<3xi64> into vector<2x3xi64> // After: %0:3 = vector.to_elements %cv0 : vector<3xi64> %1:3 = vector.to_elements %cv1 : vector<3xi64> %v2 = vector.from_elements %0#0, %0#1, %0#2, %1#0, %1#1, %1#2 : vector<2x3xi64> ``` --------- Co-authored-by: Yang Bai <[email protected]> Co-authored-by: Andrzej Warzyński <[email protected]>
1 parent c65c0e8 commit b4c31dc

File tree

6 files changed

+334
-189
lines changed

6 files changed

+334
-189
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 142 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3286,6 +3286,18 @@ LogicalResult InsertOp::verify() {
32863286
return success();
32873287
}
32883288

3289+
// Calculate the linearized position, within a vector of type `destTy`, at
3290+
// which the contiguous chunk of inserted elements begins. `positions` may have
3291+
// fewer entries than the rank of `destTy`; missing trailing indices are 0.
3292+
static int64_t calculateInsertPosition(VectorType destTy,
3293+
ArrayRef<int64_t> positions) {
3294+
// Start from an all-zero index of full rank so that unspecified trailing
// dimensions address the first element of the inserted chunk.
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3295+
assert(positions.size() <= completePositions.size() &&
3296+
"positions size must be less than or equal to destTy rank");
3297+
// Overwrite the leading indices with the given insert positions.
copy(positions, completePositions.begin());
3298+
// Linearize the full index using strides computed from the shape of `destTy`.
return linearize(completePositions, computeStrides(destTy.getShape()));
3299+
}
3300+
32893301
namespace {
32903302

32913303
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3323,6 +3335,132 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33233335
return success();
33243336
}
33253337
};
3338+
3339+
/// Pattern to optimize a chain of insertions.
3340+
///
3341+
/// This pattern identifies chains of vector.insert operations that:
3342+
/// 1. Only insert values at static positions.
3343+
/// 2. Completely initialize all elements in the resulting vector.
3344+
/// 3. Have intermediate insert operations with only one use each.
3345+
///
3346+
/// When these conditions are met, the entire chain can be replaced with a
3347+
/// single vector.from_elements operation.
3348+
///
3349+
/// To keep this pattern simple, and avoid spending too much time on matching
3350+
/// fragmented insert chains, this pattern only considers the last insert op in
3351+
/// the chain.
3352+
///
3353+
/// Example transformation:
3354+
/// %poison = ub.poison : vector<2xi32>
3355+
/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3356+
/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3357+
/// ->
3358+
/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3359+
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3360+
public:
3361+
using OpRewritePattern::OpRewritePattern;
3362+
// Matches on the trailing insert of a chain and, if the chain writes every
// element of the destination vector, replaces `op` with a
// vector.from_elements op. Returns failure() without modifying the IR when
// any of the conditions listed on the class is not met.
LogicalResult matchAndRewrite(InsertOp op,
3363+
PatternRewriter &rewriter) const override {
3364+
3365+
VectorType destTy = op.getDestVectorType();
3366+
// Scalable destination vectors are not handled by this pattern.
if (destTy.isScalable())
3367+
return failure();
3368+
// Ensure this is the trailing vector.insert op in a chain of inserts.
3369+
for (Operation *user : op.getResult().getUsers())
3370+
if (auto insertOp = dyn_cast<InsertOp>(user))
3371+
if (insertOp.getDest() == op.getResult())
3372+
return failure();
3373+
3374+
// Walk the chain backwards through the `dest` operands, starting at `op`.
// `chainInsertOps` therefore holds the inserts ordered from last to first.
InsertOp currentOp = op;
3375+
SmallVector<InsertOp> chainInsertOps;
3376+
while (currentOp) {
3377+
// Check cond 1: Dynamic position is not supported.
3378+
if (currentOp.hasDynamicPosition())
3379+
return failure();
3380+
3381+
chainInsertOps.push_back(currentOp);
3382+
currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3383+
// Check cond 3: Intermediate inserts have only one use to avoid an
3384+
// explosion of vectors.
3385+
// Note that this check applies to the predecessors only, not to `op`.
if (currentOp && !currentOp->hasOneUse())
3386+
return failure();
3387+
}
3388+
3389+
// Track which linearized destination indices are written by the chain.
int64_t vectorSize = destTy.getNumElements();
3390+
int64_t initializedCount = 0;
3391+
SmallVector<bool> initializedDestIdxs(vectorSize, false);
3392+
// Parallel arrays describing the inserts to replay, in last-to-first order.
SmallVector<int64_t> pendingInsertPos;
3393+
SmallVector<int64_t> pendingInsertSize;
3394+
SmallVector<Value> pendingInsertValues;
3395+
3396+
for (auto insertOp : chainInsertOps) {
3397+
// This pattern can do nothing with poison index.
3398+
if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3399+
return failure();
3400+
3401+
// Calculate the linearized position for inserting elements.
3402+
int64_t insertBeginPosition =
3403+
calculateInsertPosition(destTy, insertOp.getStaticPosition());
3404+
3405+
// The valueToStore operand may be a vector or a scalar. Need to handle
3406+
// both cases.
3407+
int64_t insertSize = 1;
3408+
if (auto srcVectorType =
3409+
llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3410+
insertSize = srcVectorType.getNumElements();
3411+
3412+
assert(insertBeginPosition + insertSize <= vectorSize &&
3413+
"insert would overflow the vector");
3414+
3415+
// Count each destination index only once, even if several inserts in the
// chain write to it.
for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3416+
insertBeginPosition + insertSize)) {
3417+
if (initializedDestIdxs[index])
3418+
continue;
3419+
initializedDestIdxs[index] = true;
3420+
++initializedCount;
3421+
}
3422+
3423+
// Defer the creation of ops until we can make sure the pattern can
3424+
// succeed.
3425+
pendingInsertPos.push_back(insertBeginPosition);
3426+
pendingInsertSize.push_back(insertSize);
3427+
pendingInsertValues.push_back(insertOp.getValueToStore());
3428+
3429+
// Once every element is covered, the remaining (earlier) inserts are fully
// overwritten by the ones already recorded, so they can be ignored.
if (initializedCount == vectorSize)
3430+
break;
3431+
}
3432+
3433+
// Check cond 2: all positions must be initialized.
3434+
if (initializedCount != vectorSize)
3435+
return failure();
3436+
3437+
// Replay the recorded inserts from earliest to latest (hence the reverse),
// so that where inserts overlap the later insert's value wins.
SmallVector<Value> elements(vectorSize);
3438+
for (auto [insertBeginPosition, insertSize, valueToStore] :
3439+
llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3440+
pendingInsertValues))) {
3441+
auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3442+
3443+
// Scalar insert: it occupies exactly one destination element.
if (!srcVectorType) {
3444+
elements[insertBeginPosition] = valueToStore;
3445+
continue;
3446+
}
3447+
3448+
// Vector insert: decompose the inserted vector into scalars with
// vector.to_elements, one result per element of the inserted vector.
SmallVector<Type> elementToInsertTypes(insertSize,
3449+
srcVectorType.getElementType());
3450+
// Get all elements from the vector in row-major order.
3451+
auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
3452+
op.getLoc(), elementToInsertTypes, valueToStore);
3453+
for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3454+
elements[insertBeginPosition + linearIdx] =
3455+
elementsToInsert.getResult(linearIdx);
3456+
}
3457+
}
3458+
3459+
// Replace the trailing insert with a single op that builds the whole vector.
rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3460+
return success();
3461+
}
3462+
};
3463+
33263464
} // namespace
33273465

33283466
static Attribute
@@ -3349,13 +3487,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
33493487
!insertOp->hasOneUse())
33503488
return {};
33513489

3352-
// Calculate the linearized position of the continuous chunk of elements to
3353-
// insert.
3354-
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3355-
copy(insertOp.getStaticPosition(), completePositions.begin());
3490+
// Calculate the linearized position for inserting elements.
33563491
int64_t insertBeginPosition =
3357-
linearize(completePositions, computeStrides(destTy.getShape()));
3358-
3492+
calculateInsertPosition(destTy, insertOp.getStaticPosition());
33593493
SmallVector<Attribute> insertedValues;
33603494
Type destEltType = destTy.getElementType();
33613495

@@ -3391,7 +3525,8 @@ static Value foldInsertUseChain(InsertOp insertOp) {
33913525

33923526
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
33933527
MLIRContext *context) {
3394-
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3528+
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3529+
InsertChainFullyInitialized>(context);
33953530
}
33963531

33973532
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {

mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,16 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
8383
// CHECK-LABEL: @transpose
8484
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
8585
func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
86-
// CHECK: %[[UB:.*]] = ub.poison : vector<2xi32>
8786
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
88-
// CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[UB]] [0] : i32 into vector<2xi32>
8987
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
90-
// CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
88+
// CHECK: %[[FROM_ELEMENTS0:.*]] = vector.from_elements %[[EXTRACT0]], %[[EXTRACT1]] : vector<2xi32>
9189
// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
92-
// CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[UB]] [0] : i32 into vector<2xi32>
9390
// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
94-
// CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
91+
// CHECK: %[[FROM_ELEMENTS1:.*]] = vector.from_elements %[[EXTRACT2]], %[[EXTRACT3]] : vector<2xi32>
9592
// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
96-
// CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[UB]] [0] : i32 into vector<2xi32>
9793
// CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
98-
// CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
99-
// CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
94+
// CHECK: %[[FROM_ELEMENTS2:.*]] = vector.from_elements %[[EXTRACT4]], %[[EXTRACT5]] : vector<2xi32>
95+
// CHECK: return %[[FROM_ELEMENTS0]], %[[FROM_ELEMENTS1]], %[[FROM_ELEMENTS2]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
10096
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
10197
return %0 : vector<3x2xi32>
10298
}

0 commit comments

Comments
 (0)