Skip to content

Commit addf878

Browse files
committed
[SDAG] Teach FoldConstantArithmetic to match splats inserted into vectors
This teaches FoldConstantArithmetic to match `insert_subvector undef, (splat X), N2` as a splat of X. This pattern can occur for scalable vectors when a fixed-length splat is inserted into an undef vector. This allows the cases in `fixed-subvector-insert-into-scalable.ll` to be constant-folded (where previously they would all be computed at runtime).
1 parent 1e605fc commit addf878

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7338,16 +7338,23 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
73387338
Op.getValueType().getVectorElementCount() == NumElts;
73397339
};
73407340

7341-
auto IsBuildVectorSplatVectorOrUndef = [](const SDValue &Op) {
7341+
// UNDEF: folds to undef
7342+
// BUILD_VECTOR: may have constant elements
7343+
// SPLAT_VECTOR: could be a splat of a constant
7344+
// INSERT_SUBVECTOR: could be inserting a constant splat into an undef vector
7345+
// - This pattern occurs when a fixed-length vector splat is inserted into
7346+
// a scalable vector
7347+
auto VectorOpMayConstantFold = [](const SDValue &Op) {
73427348
return Op.isUndef() || Op.getOpcode() == ISD::CONDCODE ||
73437349
Op.getOpcode() == ISD::BUILD_VECTOR ||
7344-
Op.getOpcode() == ISD::SPLAT_VECTOR;
7350+
Op.getOpcode() == ISD::SPLAT_VECTOR ||
7351+
Op.getOpcode() == ISD::INSERT_SUBVECTOR;
73457352
};
73467353

73477354
// All operands must be vector types with the same number of elements as
73487355
// the result type and must be either UNDEF or a build/splat vector
73497356
// or UNDEF scalars.
7350-
if (!llvm::all_of(Ops, IsBuildVectorSplatVectorOrUndef) ||
7357+
if (!llvm::all_of(Ops, VectorOpMayConstantFold) ||
73517358
!llvm::all_of(Ops, IsScalarOrSameVectorSize))
73527359
return SDValue();
73537360

@@ -7374,22 +7381,39 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
73747381
// a combination of BUILD_VECTOR and SPLAT_VECTOR.
73757382
unsigned NumVectorElts = NumElts.isScalable() ? 1 : NumElts.getFixedValue();
73767383

7384+
// Preprocess insert_subvector to avoid repeatedly matching the splat.
7385+
SmallVector<SDValue, 4> PreprocessedOps;
7386+
for (SDValue Op : Ops) {
7387+
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR) {
7388+
// match: `insert_subvector undef, (splat X), N2` as `splat X`
7389+
SDValue N0 = Op.getOperand(0);
7390+
auto* BV = dyn_cast<BuildVectorSDNode>(Op.getOperand(1));
7391+
if (!N0.isUndef() || !BV || !(Op = BV->getSplatValue()))
7392+
return SDValue();
7393+
}
7394+
PreprocessedOps.push_back(Op);
7395+
}
7396+
73777397
// Constant fold each scalar lane separately.
73787398
SmallVector<SDValue, 4> ScalarResults;
73797399
for (unsigned I = 0; I != NumVectorElts; I++) {
73807400
SmallVector<SDValue, 4> ScalarOps;
7381-
for (SDValue Op : Ops) {
7401+
for (SDValue Op : PreprocessedOps) {
73827402
EVT InSVT = Op.getValueType().getScalarType();
73837403
if (Op.getOpcode() != ISD::BUILD_VECTOR &&
7384-
Op.getOpcode() != ISD::SPLAT_VECTOR) {
7404+
Op.getOpcode() != ISD::SPLAT_VECTOR &&
7405+
Op.getOpcode() != ISD::INSERT_SUBVECTOR) {
73857406
if (Op.isUndef())
73867407
ScalarOps.push_back(getUNDEF(InSVT));
73877408
else
73887409
ScalarOps.push_back(Op);
73897410
continue;
73907411
}
73917412

7392-
SDValue ScalarOp =
7413+
// insert_subvector has been preprocessed, so if it was of the form
7414+
// `insert_subvector undef, (splat X), N2`, it has been replaced with the
7415+
// splat value (X).
7416+
SDValue ScalarOp = Op.getOpcode() == ISD::INSERT_SUBVECTOR ? Op :
73937417
Op.getOperand(Op.getOpcode() == ISD::SPLAT_VECTOR ? 0 : I);
73947418
EVT ScalarVT = ScalarOp.getValueType();
73957419

llvm/test/CodeGen/AArch64/fixed-subvector-insert-into-scalable.ll

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,7 @@
44
define <vscale x 4 x i32> @insert_div() {
55
; CHECK-LABEL: insert_div:
66
; CHECK: // %bb.0: // %entry
7-
; CHECK-NEXT: mov w8, #43691 // =0xaaab
8-
; CHECK-NEXT: movi v0.4s, #9
9-
; CHECK-NEXT: ptrue p0.s
10-
; CHECK-NEXT: movk w8, #43690, lsl #16
11-
; CHECK-NEXT: mov z1.s, w8
12-
; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
13-
; CHECK-NEXT: lsr z0.s, z0.s, #1
7+
; CHECK-NEXT: mov z0.s, #3 // =0x3
148
; CHECK-NEXT: ret
159
entry:
1610
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> undef, <4 x i32> splat (i32 9), i64 0)
@@ -21,8 +15,7 @@ entry:
2115
define <vscale x 4 x i32> @insert_mul() {
2216
; CHECK-LABEL: insert_mul:
2317
; CHECK: // %bb.0: // %entry
24-
; CHECK-NEXT: movi v0.4s, #1
25-
; CHECK-NEXT: mul z0.s, z0.s, #7
18+
; CHECK-NEXT: mov z0.s, #7 // =0x7
2619
; CHECK-NEXT: ret
2720
entry:
2821
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> undef, <4 x i32> splat (i32 1), i64 0)
@@ -33,8 +26,7 @@ entry:
3326
define <vscale x 4 x i32> @insert_add() {
3427
; CHECK-LABEL: insert_add:
3528
; CHECK: // %bb.0: // %entry
36-
; CHECK-NEXT: movi v0.4s, #5
37-
; CHECK-NEXT: add z0.s, z0.s, #11 // =0xb
29+
; CHECK-NEXT: mov z0.s, #16 // =0x10
3830
; CHECK-NEXT: ret
3931
entry:
4032
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> undef, <4 x i32> splat (i32 5), i64 0)
@@ -45,8 +37,7 @@ entry:
4537
define <vscale x 4 x i32> @insert_sub() {
4638
; CHECK-LABEL: insert_sub:
4739
; CHECK: // %bb.0: // %entry
48-
; CHECK-NEXT: movi v0.4s, #11
49-
; CHECK-NEXT: sub z0.s, z0.s, #11 // =0xb
40+
; CHECK-NEXT: movi v0.2d, #0000000000000000
5041
; CHECK-NEXT: ret
5142
entry:
5243
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> undef, <4 x i32> splat (i32 11), i64 0)

0 commit comments

Comments
 (0)