Skip to content

Commit c307f91

Browse files
committed
Revert "[mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements (llvm#142944)"
This reverts commit b4c31dc.
1 parent bfab808 commit c307f91

File tree

6 files changed

+189
-334
lines changed

6 files changed

+189
-334
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -3286,18 +3286,6 @@ LogicalResult InsertOp::verify() {
32863286
return success();
32873287
}
32883288

3289-
// Calculate the linearized position of the contiguous chunk of elements to
3290-
// insert, based on the shape of the value to insert and the positions to insert
3291-
// at.
3292-
static int64_t calculateInsertPosition(VectorType destTy,
3293-
ArrayRef<int64_t> positions) {
3294-
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3295-
assert(positions.size() <= completePositions.size() &&
3296-
"positions size must be less than or equal to destTy rank");
3297-
copy(positions, completePositions.begin());
3298-
return linearize(completePositions, computeStrides(destTy.getShape()));
3299-
}
3300-
33013289
namespace {
33023290

33033291
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3335,132 +3323,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33353323
return success();
33363324
}
33373325
};
3338-
3339-
/// Pattern to optimize a chain of insertions.
3340-
///
3341-
/// This pattern identifies chains of vector.insert operations that:
3342-
/// 1. Only insert values at static positions.
3343-
/// 2. Completely initialize all elements in the resulting vector.
3344-
/// 3. All intermediate insert operations have only one use.
3345-
///
3346-
/// When these conditions are met, the entire chain can be replaced with a
3347-
/// single vector.from_elements operation.
3348-
///
3349-
/// To keep this pattern simple, and avoid spending too much time on matching
3350-
/// fragmented insert chains, this pattern only considers the last insert op in
3351-
/// the chain.
3352-
///
3353-
/// Example transformation:
3354-
/// %poison = ub.poison : vector<2xi32>
3355-
/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3356-
/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3357-
/// ->
3358-
/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3359-
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3360-
public:
3361-
using OpRewritePattern::OpRewritePattern;
3362-
LogicalResult matchAndRewrite(InsertOp op,
3363-
PatternRewriter &rewriter) const override {
3364-
3365-
VectorType destTy = op.getDestVectorType();
3366-
if (destTy.isScalable())
3367-
return failure();
3368-
// Ensure this is the trailing vector.insert op in a chain of inserts.
3369-
for (Operation *user : op.getResult().getUsers())
3370-
if (auto insertOp = dyn_cast<InsertOp>(user))
3371-
if (insertOp.getDest() == op.getResult())
3372-
return failure();
3373-
3374-
InsertOp currentOp = op;
3375-
SmallVector<InsertOp> chainInsertOps;
3376-
while (currentOp) {
3377-
// Check cond 1: Dynamic position is not supported.
3378-
if (currentOp.hasDynamicPosition())
3379-
return failure();
3380-
3381-
chainInsertOps.push_back(currentOp);
3382-
currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3383-
// Check cond 3: Intermediate inserts have only one use to avoid an
3384-
// explosion of vectors.
3385-
if (currentOp && !currentOp->hasOneUse())
3386-
return failure();
3387-
}
3388-
3389-
int64_t vectorSize = destTy.getNumElements();
3390-
int64_t initializedCount = 0;
3391-
SmallVector<bool> initializedDestIdxs(vectorSize, false);
3392-
SmallVector<int64_t> pendingInsertPos;
3393-
SmallVector<int64_t> pendingInsertSize;
3394-
SmallVector<Value> pendingInsertValues;
3395-
3396-
for (auto insertOp : chainInsertOps) {
3397-
// This pattern cannot handle a poison index.
3398-
if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3399-
return failure();
3400-
3401-
// Calculate the linearized position for inserting elements.
3402-
int64_t insertBeginPosition =
3403-
calculateInsertPosition(destTy, insertOp.getStaticPosition());
3404-
3405-
// The valueToStore operand may be a vector or a scalar. Need to handle
3406-
// both cases.
3407-
int64_t insertSize = 1;
3408-
if (auto srcVectorType =
3409-
llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3410-
insertSize = srcVectorType.getNumElements();
3411-
3412-
assert(insertBeginPosition + insertSize <= vectorSize &&
3413-
"insert would overflow the vector");
3414-
3415-
for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3416-
insertBeginPosition + insertSize)) {
3417-
if (initializedDestIdxs[index])
3418-
continue;
3419-
initializedDestIdxs[index] = true;
3420-
++initializedCount;
3421-
}
3422-
3423-
// Defer the creation of ops until we can make sure the pattern can
3424-
// succeed.
3425-
pendingInsertPos.push_back(insertBeginPosition);
3426-
pendingInsertSize.push_back(insertSize);
3427-
pendingInsertValues.push_back(insertOp.getValueToStore());
3428-
3429-
if (initializedCount == vectorSize)
3430-
break;
3431-
}
3432-
3433-
// Check cond 2: all positions must be initialized.
3434-
if (initializedCount != vectorSize)
3435-
return failure();
3436-
3437-
SmallVector<Value> elements(vectorSize);
3438-
for (auto [insertBeginPosition, insertSize, valueToStore] :
3439-
llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3440-
pendingInsertValues))) {
3441-
auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3442-
3443-
if (!srcVectorType) {
3444-
elements[insertBeginPosition] = valueToStore;
3445-
continue;
3446-
}
3447-
3448-
SmallVector<Type> elementToInsertTypes(insertSize,
3449-
srcVectorType.getElementType());
3450-
// Get all elements from the vector in row-major order.
3451-
auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
3452-
op.getLoc(), elementToInsertTypes, valueToStore);
3453-
for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3454-
elements[insertBeginPosition + linearIdx] =
3455-
elementsToInsert.getResult(linearIdx);
3456-
}
3457-
}
3458-
3459-
rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3460-
return success();
3461-
}
3462-
};
3463-
34643326
} // namespace
34653327

34663328
static Attribute
@@ -3487,9 +3349,13 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
34873349
!insertOp->hasOneUse())
34883350
return {};
34893351

3490-
// Calculate the linearized position for inserting elements.
3352+
// Calculate the linearized position of the contiguous chunk of elements to
3353+
// insert.
3354+
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3355+
copy(insertOp.getStaticPosition(), completePositions.begin());
34913356
int64_t insertBeginPosition =
3492-
calculateInsertPosition(destTy, insertOp.getStaticPosition());
3357+
linearize(completePositions, computeStrides(destTy.getShape()));
3358+
34933359
SmallVector<Attribute> insertedValues;
34943360
Type destEltType = destTy.getElementType();
34953361

@@ -3525,8 +3391,7 @@ static Value foldInsertUseChain(InsertOp insertOp) {
35253391

35263392
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
35273393
MLIRContext *context) {
3528-
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3529-
InsertChainFullyInitialized>(context);
3394+
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
35303395
}
35313396

35323397
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {

mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,20 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
8383
// CHECK-LABEL: @transpose
8484
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
8585
func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
86+
// CHECK: %[[UB:.*]] = ub.poison : vector<2xi32>
8687
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
88+
// CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[UB]] [0] : i32 into vector<2xi32>
8789
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
88-
// CHECK: %[[FROM_ELEMENTS0:.*]] = vector.from_elements %[[EXTRACT0]], %[[EXTRACT1]] : vector<2xi32>
90+
// CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
8991
// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
92+
// CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[UB]] [0] : i32 into vector<2xi32>
9093
// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
91-
// CHECK: %[[FROM_ELEMENTS1:.*]] = vector.from_elements %[[EXTRACT2]], %[[EXTRACT3]] : vector<2xi32>
94+
// CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
9295
// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
96+
// CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[UB]] [0] : i32 into vector<2xi32>
9397
// CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
94-
// CHECK: %[[FROM_ELEMENTS2:.*]] = vector.from_elements %[[EXTRACT4]], %[[EXTRACT5]] : vector<2xi32>
95-
// CHECK: return %[[FROM_ELEMENTS0]], %[[FROM_ELEMENTS1]], %[[FROM_ELEMENTS2]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
98+
// CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
99+
// CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
96100
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
97101
return %0 : vector<3x2xi32>
98102
}

0 commit comments

Comments
 (0)