Skip to content

Commit 8243dfc

Browse files
Groverkssfabianmcg
authored andcommitted
Revert "[mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements (llvm#142944)"
This reverts commit b4c31dc.
1 parent 9c7727c commit 8243dfc

File tree

6 files changed

+189
-334
lines changed

6 files changed

+189
-334
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -3290,18 +3290,6 @@ LogicalResult InsertOp::verify() {
32903290
return success();
32913291
}
32923292

3293-
// Calculate the linearized position of the continuous chunk of elements to
3294-
// insert, based on the shape of the value to insert and the positions to insert
3295-
// at.
3296-
static int64_t calculateInsertPosition(VectorType destTy,
3297-
ArrayRef<int64_t> positions) {
3298-
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3299-
assert(positions.size() <= completePositions.size() &&
3300-
"positions size must be less than or equal to destTy rank");
3301-
copy(positions, completePositions.begin());
3302-
return linearize(completePositions, computeStrides(destTy.getShape()));
3303-
}
3304-
33053293
namespace {
33063294

33073295
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3339,132 +3327,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33393327
return success();
33403328
}
33413329
};
3342-
3343-
/// Pattern to optimize a chain of insertions.
3344-
///
3345-
/// This pattern identifies chains of vector.insert operations that:
3346-
/// 1. Only insert values at static positions.
3347-
/// 2. Completely initialize all elements in the resulting vector.
3348-
/// 3. All intermediate insert operations have only one use.
3349-
///
3350-
/// When these conditions are met, the entire chain can be replaced with a
3351-
/// single vector.from_elements operation.
3352-
///
3353-
/// To keep this pattern simple, and avoid spending too much time on matching
3354-
/// fragmented insert chains, this pattern only considers the last insert op in
3355-
/// the chain.
3356-
///
3357-
/// Example transformation:
3358-
/// %poison = ub.poison : vector<2xi32>
3359-
/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3360-
/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3361-
/// ->
3362-
/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3363-
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3364-
public:
3365-
using OpRewritePattern::OpRewritePattern;
3366-
LogicalResult matchAndRewrite(InsertOp op,
3367-
PatternRewriter &rewriter) const override {
3368-
3369-
VectorType destTy = op.getDestVectorType();
3370-
if (destTy.isScalable())
3371-
return failure();
3372-
// Ensure this is the trailing vector.insert op in a chain of inserts.
3373-
for (Operation *user : op.getResult().getUsers())
3374-
if (auto insertOp = dyn_cast<InsertOp>(user))
3375-
if (insertOp.getDest() == op.getResult())
3376-
return failure();
3377-
3378-
InsertOp currentOp = op;
3379-
SmallVector<InsertOp> chainInsertOps;
3380-
while (currentOp) {
3381-
// Check cond 1: Dynamic position is not supported.
3382-
if (currentOp.hasDynamicPosition())
3383-
return failure();
3384-
3385-
chainInsertOps.push_back(currentOp);
3386-
currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3387-
// Check cond 3: Intermediate inserts have only one use to avoid an
3388-
// explosion of vectors.
3389-
if (currentOp && !currentOp->hasOneUse())
3390-
return failure();
3391-
}
3392-
3393-
int64_t vectorSize = destTy.getNumElements();
3394-
int64_t initializedCount = 0;
3395-
SmallVector<bool> initializedDestIdxs(vectorSize, false);
3396-
SmallVector<int64_t> pendingInsertPos;
3397-
SmallVector<int64_t> pendingInsertSize;
3398-
SmallVector<Value> pendingInsertValues;
3399-
3400-
for (auto insertOp : chainInsertOps) {
3401-
// This pattern can do nothing with poison index.
3402-
if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3403-
return failure();
3404-
3405-
// Calculate the linearized position for inserting elements.
3406-
int64_t insertBeginPosition =
3407-
calculateInsertPosition(destTy, insertOp.getStaticPosition());
3408-
3409-
// The valueToStore operand may be a vector or a scalar. Need to handle
3410-
// both cases.
3411-
int64_t insertSize = 1;
3412-
if (auto srcVectorType =
3413-
llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3414-
insertSize = srcVectorType.getNumElements();
3415-
3416-
assert(insertBeginPosition + insertSize <= vectorSize &&
3417-
"insert would overflow the vector");
3418-
3419-
for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3420-
insertBeginPosition + insertSize)) {
3421-
if (initializedDestIdxs[index])
3422-
continue;
3423-
initializedDestIdxs[index] = true;
3424-
++initializedCount;
3425-
}
3426-
3427-
// Defer the creation of ops before we can make sure the pattern can
3428-
// succeed.
3429-
pendingInsertPos.push_back(insertBeginPosition);
3430-
pendingInsertSize.push_back(insertSize);
3431-
pendingInsertValues.push_back(insertOp.getValueToStore());
3432-
3433-
if (initializedCount == vectorSize)
3434-
break;
3435-
}
3436-
3437-
// Check cond 2: all positions must be initialized.
3438-
if (initializedCount != vectorSize)
3439-
return failure();
3440-
3441-
SmallVector<Value> elements(vectorSize);
3442-
for (auto [insertBeginPosition, insertSize, valueToStore] :
3443-
llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3444-
pendingInsertValues))) {
3445-
auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3446-
3447-
if (!srcVectorType) {
3448-
elements[insertBeginPosition] = valueToStore;
3449-
continue;
3450-
}
3451-
3452-
SmallVector<Type> elementToInsertTypes(insertSize,
3453-
srcVectorType.getElementType());
3454-
// Get all elements from the vector in row-major order.
3455-
auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
3456-
op.getLoc(), elementToInsertTypes, valueToStore);
3457-
for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3458-
elements[insertBeginPosition + linearIdx] =
3459-
elementsToInsert.getResult(linearIdx);
3460-
}
3461-
}
3462-
3463-
rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3464-
return success();
3465-
}
3466-
};
3467-
34683330
} // namespace
34693331

34703332
static Attribute
@@ -3491,9 +3353,13 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
34913353
!insertOp->hasOneUse())
34923354
return {};
34933355

3494-
// Calculate the linearized position for inserting elements.
3356+
// Calculate the linearized position of the continuous chunk of elements to
3357+
// insert.
3358+
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3359+
copy(insertOp.getStaticPosition(), completePositions.begin());
34953360
int64_t insertBeginPosition =
3496-
calculateInsertPosition(destTy, insertOp.getStaticPosition());
3361+
linearize(completePositions, computeStrides(destTy.getShape()));
3362+
34973363
SmallVector<Attribute> insertedValues;
34983364
Type destEltType = destTy.getElementType();
34993365

@@ -3529,8 +3395,7 @@ static Value foldInsertUseChain(InsertOp insertOp) {
35293395

35303396
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
35313397
MLIRContext *context) {
3532-
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3533-
InsertChainFullyInitialized>(context);
3398+
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
35343399
}
35353400

35363401
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {

mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,20 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
8383
// CHECK-LABEL: @transpose
8484
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
8585
func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
86+
// CHECK: %[[UB:.*]] = ub.poison : vector<2xi32>
8687
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
88+
// CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[UB]] [0] : i32 into vector<2xi32>
8789
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
88-
// CHECK: %[[FROM_ELEMENTS0:.*]] = vector.from_elements %[[EXTRACT0]], %[[EXTRACT1]] : vector<2xi32>
90+
// CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
8991
// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
92+
// CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[UB]] [0] : i32 into vector<2xi32>
9093
// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
91-
// CHECK: %[[FROM_ELEMENTS1:.*]] = vector.from_elements %[[EXTRACT2]], %[[EXTRACT3]] : vector<2xi32>
94+
// CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
9295
// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
96+
// CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[UB]] [0] : i32 into vector<2xi32>
9397
// CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
94-
// CHECK: %[[FROM_ELEMENTS2:.*]] = vector.from_elements %[[EXTRACT4]], %[[EXTRACT5]] : vector<2xi32>
95-
// CHECK: return %[[FROM_ELEMENTS0]], %[[FROM_ELEMENTS1]], %[[FROM_ELEMENTS2]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
98+
// CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
99+
// CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
96100
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
97101
return %0 : vector<3x2xi32>
98102
}

0 commit comments

Comments
 (0)