Skip to content

Commit 22bcb47

Browse files
MacDue authored and aokblast committed
[InstCombine] Constant fold binops through vector.insert (llvm#164624)
This patch improves constant folding through `llvm.vector.insert`. It does not change anything for fixed-length vectors (which can already be folded to ConstantVectors for these cases), but folds scalable vectors that otherwise would not be folded. These folds preserve the destination vector (which could be undef or poison), giving targets more freedom in lowering the operations.
1 parent 7ef2cf1 commit 22bcb47

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,26 @@ struct constantexpr_match {
198198
/// expression.
199199
inline constantexpr_match m_ConstantExpr() { return constantexpr_match(); }
200200

201+
template <typename SubPattern_t> struct Splat_match {
202+
SubPattern_t SubPattern;
203+
Splat_match(const SubPattern_t &SP) : SubPattern(SP) {}
204+
205+
template <typename OpTy> bool match(OpTy *V) const {
206+
if (auto *C = dyn_cast<Constant>(V)) {
207+
auto *Splat = C->getSplatValue();
208+
return Splat ? SubPattern.match(Splat) : false;
209+
}
210+
// TODO: Extend to other cases (e.g. shufflevectors).
211+
return false;
212+
}
213+
};
214+
215+
/// Match a constant splat. TODO: Extend this to non-constant splats.
216+
template <typename T>
217+
inline Splat_match<T> m_ConstantSplat(const T &SubPattern) {
218+
return SubPattern;
219+
}
220+
201221
/// Match an arbitrary basic block value and ignore it.
202222
inline class_match<BasicBlock> m_BasicBlock() {
203223
return class_match<BasicBlock>();
@@ -2925,6 +2945,12 @@ inline typename m_Intrinsic_Ty<Opnd0>::Ty m_VecReverse(const Opnd0 &Op0) {
29252945
return m_Intrinsic<Intrinsic::vector_reverse>(Op0);
29262946
}
29272947

2948+
template <typename Opnd0, typename Opnd1, typename Opnd2>
2949+
inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2>::Ty
2950+
m_VectorInsert(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2) {
2951+
return m_Intrinsic<Intrinsic::vector_insert>(Op0, Op1, Op2);
2952+
}
2953+
29282954
//===----------------------------------------------------------------------===//
29292955
// Matchers for two-operands operators with the operators in either order
29302956
//

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2327,6 +2327,18 @@ Constant *InstCombinerImpl::unshuffleConstant(ArrayRef<int> ShMask, Constant *C,
23272327
return ConstantVector::get(NewVecC);
23282328
}
23292329

2330+
// Get the result of `Vector Op Splat` (or Splat Op Vector if \p SplatLHS).
2331+
static Constant *constantFoldBinOpWithSplat(unsigned Opcode, Constant *Vector,
2332+
Constant *Splat, bool SplatLHS,
2333+
const DataLayout &DL) {
2334+
ElementCount EC = cast<VectorType>(Vector->getType())->getElementCount();
2335+
Constant *LHS = ConstantVector::getSplat(EC, Splat);
2336+
Constant *RHS = Vector;
2337+
if (!SplatLHS)
2338+
std::swap(LHS, RHS);
2339+
return ConstantFoldBinaryOpOperands(Opcode, LHS, RHS, DL);
2340+
}
2341+
23302342
Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
23312343
if (!isa<VectorType>(Inst.getType()))
23322344
return nullptr;
@@ -2338,6 +2350,37 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
23382350
assert(cast<VectorType>(RHS->getType())->getElementCount() ==
23392351
cast<VectorType>(Inst.getType())->getElementCount());
23402352

2353+
auto foldConstantsThroughSubVectorInsertSplat =
2354+
[&](Value *MaybeSubVector, Value *MaybeSplat,
2355+
bool SplatLHS) -> Instruction * {
2356+
Value *Idx;
2357+
Constant *Splat, *SubVector, *Dest;
2358+
if (!match(MaybeSplat, m_ConstantSplat(m_Constant(Splat))) ||
2359+
!match(MaybeSubVector,
2360+
m_VectorInsert(m_Constant(Dest), m_Constant(SubVector),
2361+
m_Value(Idx))))
2362+
return nullptr;
2363+
SubVector =
2364+
constantFoldBinOpWithSplat(Opcode, SubVector, Splat, SplatLHS, DL);
2365+
Dest = constantFoldBinOpWithSplat(Opcode, Dest, Splat, SplatLHS, DL);
2366+
if (!SubVector || !Dest)
2367+
return nullptr;
2368+
auto *InsertVector =
2369+
Builder.CreateInsertVector(Dest->getType(), Dest, SubVector, Idx);
2370+
return replaceInstUsesWith(Inst, InsertVector);
2371+
};
2372+
2373+
// If one operand is a constant splat and the other operand is a
2374+
// `vector.insert` where both the destination and subvector are constant,
2375+
// apply the operation to both the destination and subvector, returning a new
2376+
// constant `vector.insert`. This helps constant folding for scalable vectors.
2377+
if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat(
2378+
/*MaybeSubVector=*/LHS, /*MaybeSplat=*/RHS, /*SplatLHS=*/false))
2379+
return Folded;
2380+
if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat(
2381+
/*MaybeSubVector=*/RHS, /*MaybeSplat=*/LHS, /*SplatLHS=*/true))
2382+
return Folded;
2383+
23412384
// If both operands of the binop are vector concatenations, then perform the
23422385
// narrow binop on each pair of the source operands followed by concatenation
23432386
// of the results.
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -passes=instcombine %s | FileCheck %s
; RUN: opt -S -passes=instcombine %s \
; RUN:   -use-constant-int-for-fixed-length-splat \
; RUN:   -use-constant-fp-for-fixed-length-splat \
; RUN:   -use-constant-int-for-scalable-splat \
; RUN:   -use-constant-fp-for-scalable-splat | FileCheck %s

define <vscale x 4 x i32> @insert_div() {
; CHECK-LABEL: @insert_div(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[DIV:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 3), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[DIV]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 9), i64 0)
  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
  ret <vscale x 4 x i32> %div
}

define <vscale x 4 x i32> @insert_div_splat_lhs() {
; CHECK-LABEL: @insert_div_splat_lhs(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[DIV:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 5), <4 x i32> splat (i32 2), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[DIV]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat(i32 2), <4 x i32> splat (i32 5), i64 0)
  %div = udiv <vscale x 4 x i32> splat (i32 10), %0
  ret <vscale x 4 x i32> %div
}

define <vscale x 4 x i32> @insert_div_mixed_splat() {
; CHECK-LABEL: @insert_div_mixed_splat(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[DIV:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 6), <4 x i32> splat (i32 3), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[DIV]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 18), <4 x i32> splat (i32 9), i64 0)
  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
  ret <vscale x 4 x i32> %div
}

define <vscale x 4 x i32> @insert_mul() {
; CHECK-LABEL: @insert_mul(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[MUL:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 7), i64 4)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[MUL]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 1), i64 4)
  %mul = mul <vscale x 4 x i32> %0, splat (i32 7)
  ret <vscale x 4 x i32> %mul
}

define <vscale x 4 x i32> @insert_add() {
; CHECK-LABEL: @insert_add(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 16), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 5), i64 0)
  %add = add <vscale x 4 x i32> %0, splat (i32 11)
  ret <vscale x 4 x i32> %add
}

define <vscale x 4 x i32> @insert_add_non_splat_subvector() {
; CHECK-LABEL: @insert_add_non_splat_subvector(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> <i32 101, i32 102, i32 103, i32 104>, i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i64 0)
  %add = add <vscale x 4 x i32> %0, splat (i32 100)
  ret <vscale x 4 x i32> %add
}

define <vscale x 4 x float> @insert_add_fp() {
; CHECK-LABEL: @insert_add_fp(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> splat (float 6.250000e+00), <4 x float> splat (float 5.500000e+00), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x float> [[ADD]]
;
entry:
  %0 = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> splat(float 1.25), <4 x float> splat (float 0.5), i64 0)
  %add = fadd <vscale x 4 x float> %0, splat (float 5.0)
  ret <vscale x 4 x float> %add
}

define <vscale x 8 x i32> @insert_add_scalable_subvector() {
; CHECK-LABEL: @insert_add_scalable_subvector(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> splat (i32 20), <vscale x 4 x i32> splat (i32 -4), i64 0)
; CHECK-NEXT:    ret <vscale x 8 x i32> [[ADD]]
;
entry:
  %0 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> splat(i32 16), <vscale x 4 x i32> splat (i32 -8), i64 0)
  %add = add <vscale x 8 x i32> %0, splat (i32 4)
  ret <vscale x 8 x i32> %add
}

define <vscale x 4 x i32> @insert_sub() {
; CHECK-LABEL: @insert_sub(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[SUB:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> zeroinitializer, i64 8)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[SUB]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 11), i64 8)
  %sub = add <vscale x 4 x i32> %0, splat (i32 -11)
  ret <vscale x 4 x i32> %sub
}

define <vscale x 4 x i32> @insert_and_partially_undef() {
; CHECK-LABEL: @insert_and_partially_undef(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[AND:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> zeroinitializer, <4 x i32> splat (i32 4), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[AND]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> undef, <4 x i32> splat (i32 6), i64 0)
  %and = and <vscale x 4 x i32> %0, splat (i32 4)
  ret <vscale x 4 x i32> %and
}

define <vscale x 4 x i32> @insert_fold_chain() {
; CHECK-LABEL: @insert_fold_chain(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 11), <4 x i32> splat (i32 8), i64 0)
; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 21), <4 x i32> splat (i32 12), i64 0)
  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
  %add = add <vscale x 4 x i32> %div, splat (i32 4)
  ret <vscale x 4 x i32> %add
}

; TODO: This could be folded more.
define <vscale x 4 x i32> @insert_add_both_insert_vector() {
; CHECK-LABEL: @insert_add_both_insert_vector(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[TMP0:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 10), <4 x i32> splat (i32 5), i64 0)
; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 -1), <4 x i32> splat (i32 2), i64 0)
; CHECK-NEXT:    [[ADD:%.*]] = add <vscale x 4 x i32> [[TMP0]], [[TMP1]]
; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
;
entry:
  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat(i32 10), <4 x i32> splat (i32 5), i64 0)
  %1 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat(i32 -1), <4 x i32> splat (i32 2), i64 0)
  %add = add <vscale x 4 x i32> %0, %1
  ret <vscale x 4 x i32> %add
}

0 commit comments

Comments
 (0)