Skip to content

Commit ddd3df9

Browse files
Implement the pattern as a folder function, remove the one-use check, and update the tests.
1 parent 227158d commit ddd3df9

File tree

2 files changed

+71
-28
lines changed

2 files changed

+71
-28
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3334,29 +3334,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33343334
return success();
33353335
}
33363336
};
3337-
3338-
/// Pattern to rewrite a InsertOp(InsertOp) to InsertOp.
3339-
class InsertInsertToInsert final : public OpRewritePattern<InsertOp> {
3340-
public:
3341-
using OpRewritePattern::OpRewritePattern;
3342-
LogicalResult matchAndRewrite(InsertOp op,
3343-
PatternRewriter &rewriter) const override {
3344-
auto destInsert = op.getDest().getDefiningOp<InsertOp>();
3345-
if (!destInsert)
3346-
return failure();
3347-
3348-
if (!destInsert->hasOneUse())
3349-
return failure();
3350-
3351-
if (op.getMixedPosition() != destInsert.getMixedPosition())
3352-
return failure();
3353-
3354-
rewriter.replaceOpWithNewOp<InsertOp>(
3355-
op, op.getValueToStore(), destInsert.getDest(), op.getMixedPosition());
3356-
return success();
3357-
}
3358-
};
3359-
33603337
} // namespace
33613338

33623339
static Attribute
@@ -3409,13 +3386,26 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
34093386
return newAttr;
34103387
}
34113388

3389+
/// Folder to replace the `dest` operand of the insert op with the `dest` of the
3390+
/// insert op that defines it, when both ops insert at the same position.
///
/// Only one link of the chain is bypassed per call. `insertOp` is updated in
/// place and its own result is returned; a null Value is returned (and nothing
/// is modified) when `dest` is not defined by an InsertOp or when the two
/// positions differ.
3391+
static Value foldInsertUseChain(InsertOp insertOp) {
3392+
// The fold only applies when `dest` is itself produced by another insert op.
auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3393+
if (!destInsert)
3394+
return {};
3395+
3396+
// Both inserts must target the same (static and dynamic) position.
if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3397+
return {};
3398+
3399+
// Rewire this op to read from the defining insert's `dest`; `destInsert`
// itself is left untouched.
// NOTE(review): operand index 1 is presumably `dest` (after the value to
// store) -- confirm against the op definition.
insertOp.setOperand(1, destInsert.getDest());
3400+
return insertOp.getResult();
3401+
}
3402+
34123403
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
34133404
MLIRContext *context) {
3414-
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3415-
InsertInsertToInsert>(context);
3405+
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
34163406
}
34173407

3418-
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3408+
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
34193409
// Do not create constants with more than `vectorSizeFoldThreshold` elements,
34203410
// unless the source vector constant has a single use.
34213411
constexpr int64_t vectorSizeFoldThreshold = 256;
@@ -3430,6 +3420,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
34303420
SmallVector<Value> operands = {getValueToStore(), getDest()};
34313421
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
34323422

3423+
if (auto res = foldInsertUseChain(*this))
3424+
return res;
34333425
if (auto res = foldPoisonIndexInsertExtractOp(
34343426
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
34353427
return res;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,14 +3449,65 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
34493449

34503450
// -----
34513451

3452-
// CHECK-LABEL: @insert_insert_to_insert(
3452+
// CHECK-LABEL: @fold_insert_use_chain_static_pos(
34533453
// CHECK-SAME: %[[ARG:.*]]: vector<4xf32>,
34543454
// CHECK-SAME: %[[VAL:.*]]: f32) -> vector<4xf32> {
34553455
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
34563456
// CHECK: return %[[RES]] : vector<4xf32>
3457-
func.func @insert_insert_to_insert(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
3457+
func.func @fold_insert_use_chain_static_pos(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
34583458
%v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
34593459
%v_1 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
34603460
%v_2 = vector.insert %value, %v_1[0] : f32 into vector<4xf32>
34613461
return %v_2 : vector<4xf32>
34623462
}
3463+
3464+
// -----
3465+
3466+
// CHECK-LABEL: @fold_insert_use_chain_dynamic_pos(
3467+
// CHECK-SAME: %[[ARG:.*]]: vector<4x4xf32>,
3468+
// CHECK-SAME: %[[VAL:.*]]: f32,
3469+
// CHECK-SAME: %[[POS:.*]]: index) -> vector<4x4xf32> {
3470+
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] {{\[}}%[[POS]], 0] : f32 into vector<4x4xf32>
3471+
// CHECK: return %[[RES]] : vector<4x4xf32>
3472+
func.func @fold_insert_use_chain_dynamic_pos(%arg : vector<4x4xf32>, %value : f32, %pos: index) -> vector<4x4xf32> {
3473+
%v_0 = vector.insert %value, %arg[%pos, 0] : f32 into vector<4x4xf32>
3474+
%v_1 = vector.insert %value, %v_0[%pos, 0] : f32 into vector<4x4xf32>
3475+
%v_2 = vector.insert %value, %v_1[%pos, 0] : f32 into vector<4x4xf32>
3476+
return %v_2 : vector<4x4xf32>
3477+
}
3478+
3479+
// -----
3480+
3481+
// CHECK-LABEL: @fold_insert_use_chain_add_float(
3482+
// CHECK-SAME: %[[VEC_0:.*]]: vector<4xf32>,
3483+
// CHECK-SAME: %[[VAL:.*]]: f32) -> vector<4xf32> {
3484+
// CHECK: %[[VEC_1:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
3485+
// CHECK: %[[VEC_2:.*]] = arith.addf %[[VEC_1]], %[[VEC_1]] : vector<4xf32>
3486+
// CHECK: %[[VEC_3:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
3487+
// CHECK: %[[VEC_4:.*]] = arith.addf %[[VEC_2]], %[[VEC_3]] : vector<4xf32>
3488+
// CHECK: return %[[VEC_4]] : vector<4xf32>
3489+
func.func @fold_insert_use_chain_add_float(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
3490+
%v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
3491+
%v_1 = arith.addf %v_0, %v_0 : vector<4xf32>
3492+
%v_2 = vector.insert %value, %v_0[0] : f32 into vector<4xf32>
3493+
%v_3 = arith.addf %v_1, %v_2 : vector<4xf32>
3494+
return %v_3 : vector<4xf32>
3495+
}
3496+
3497+
// -----
3498+
3499+
// CHECK-LABEL: @fold_insert_use_chain_add_float_pos_mismatch(
3500+
// CHECK-SAME: %[[VEC_0:.*]]: vector<4xf32>,
3501+
// CHECK-SAME: %[[VAL:.*]]: f32) -> vector<4xf32> {
3502+
// CHECK: %[[VEC_1:.*]] = vector.insert %[[VAL]], %[[VEC_0]] [0] : f32 into vector<4xf32>
3503+
// CHECK: %[[VEC_2:.*]] = arith.addf %[[VEC_1]], %[[VEC_1]] : vector<4xf32>
3504+
// CHECK: %[[VEC_3:.*]] = vector.insert %[[VAL]], %[[VEC_1]] [1] : f32 into vector<4xf32>
3505+
// CHECK: %[[VEC_4:.*]] = arith.addf %[[VEC_2]], %[[VEC_3]] : vector<4xf32>
3506+
// CHECK: return %[[VEC_4]] : vector<4xf32>
3507+
func.func @fold_insert_use_chain_add_float_pos_mismatch(%v : vector<4xf32>, %value : f32) -> vector<4xf32> {
3508+
%v_0 = vector.insert %value, %v[0] : f32 into vector<4xf32>
3509+
%v_1 = arith.addf %v_0, %v_0 : vector<4xf32>
3510+
%v_2 = vector.insert %value, %v_0[1] : f32 into vector<4xf32>
3511+
%v_3 = arith.addf %v_1, %v_2 : vector<4xf32>
3512+
return %v_3 : vector<4xf32>
3513+
}

0 commit comments

Comments
 (0)