Skip to content

Commit 227158d

Browse files
Add InsertInsertToInsert to insert op canonicalize patterns
1 parent 1d5d125 commit 227158d

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3335,6 +3335,28 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33353335
}
33363336
};
33373337

3338+
/// Pattern to rewrite a InsertOp(InsertOp) to InsertOp.
3339+
class InsertInsertToInsert final : public OpRewritePattern<InsertOp> {
3340+
public:
3341+
using OpRewritePattern::OpRewritePattern;
3342+
LogicalResult matchAndRewrite(InsertOp op,
3343+
PatternRewriter &rewriter) const override {
3344+
auto destInsert = op.getDest().getDefiningOp<InsertOp>();
3345+
if (!destInsert)
3346+
return failure();
3347+
3348+
if (!destInsert->hasOneUse())
3349+
return failure();
3350+
3351+
if (op.getMixedPosition() != destInsert.getMixedPosition())
3352+
return failure();
3353+
3354+
rewriter.replaceOpWithNewOp<InsertOp>(
3355+
op, op.getValueToStore(), destInsert.getDest(), op.getMixedPosition());
3356+
return success();
3357+
}
3358+
};
3359+
33383360
} // namespace
33393361

33403362
static Attribute
@@ -3389,7 +3411,8 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
33893411

33903412
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
33913413
MLIRContext *context) {
3392-
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3414+
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3415+
InsertInsertToInsert>(context);
33933416
}
33943417

33953418
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3446,3 +3446,17 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
34463446
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
34473447
return %res : vector<4x1xi32>
34483448
}
3449+
3450+
// -----
3451+
3452+
// CHECK-LABEL: @insert_insert_to_insert(
3453+
// CHECK-SAME: %[[ARG:.*]]: vector<4xf32>,
3454+
// CHECK-SAME: %[[VAL:.*]]: f32) -> vector<4xf32> {
3455+
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
3456+
// CHECK: return %[[RES]] : vector<4xf32>
3457+
func.func @insert_insert_to_insert(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
3458+
%v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
3459+
%v_1 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
3460+
%v_2 = vector.insert %value, %v_1[0] : f32 into vector<4xf32>
3461+
return %v_2 : vector<4xf32>
3462+
}

0 commit comments

Comments
 (0)