Skip to content

Commit d17d32a

Browse files
[LLVM][ConstProp] Enable intrinsic simplifications for vector ConstantInt based operands. (llvm#159358)
Simplifcation of vector.reduce intrinsics are prevented by an early bailout for ConstantInt base operands. This PR removes the bailout and updates the tests to show matching output when -use-constant-int-for-*-splat is used.
1 parent d74c976 commit d17d32a

File tree

6 files changed

+111
-32
lines changed

6 files changed

+111
-32
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2177,16 +2177,13 @@ Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
21772177
return PoisonValue::get(VT->getElementType());
21782178

21792179
// TODO: Handle undef.
2180-
if (!isa<ConstantVector>(Op) && !isa<ConstantDataVector>(Op))
2181-
return nullptr;
2182-
2183-
auto *EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(0U));
2180+
auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U));
21842181
if (!EltC)
21852182
return nullptr;
21862183

21872184
APInt Acc = EltC->getValue();
21882185
for (unsigned I = 1, E = VT->getNumElements(); I != E; I++) {
2189-
if (!(EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(I))))
2186+
if (!(EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(I))))
21902187
return nullptr;
21912188
const APInt &X = EltC->getValue();
21922189
switch (IID) {
@@ -3059,35 +3056,25 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
30593056
Val = Val | Val << 1;
30603057
return ConstantInt::get(Ty, Val);
30613058
}
3062-
3063-
default:
3064-
return nullptr;
30653059
}
30663060
}
30673061

3068-
switch (IntrinsicID) {
3069-
default: break;
3070-
case Intrinsic::vector_reduce_add:
3071-
case Intrinsic::vector_reduce_mul:
3072-
case Intrinsic::vector_reduce_and:
3073-
case Intrinsic::vector_reduce_or:
3074-
case Intrinsic::vector_reduce_xor:
3075-
case Intrinsic::vector_reduce_smin:
3076-
case Intrinsic::vector_reduce_smax:
3077-
case Intrinsic::vector_reduce_umin:
3078-
case Intrinsic::vector_reduce_umax:
3079-
if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0]))
3080-
return C;
3081-
break;
3082-
}
3083-
3084-
// Support ConstantVector in case we have an Undef in the top.
3085-
if (isa<ConstantVector>(Operands[0]) ||
3086-
isa<ConstantDataVector>(Operands[0]) ||
3087-
isa<ConstantAggregateZero>(Operands[0])) {
3062+
if (Operands[0]->getType()->isVectorTy()) {
30883063
auto *Op = cast<Constant>(Operands[0]);
30893064
switch (IntrinsicID) {
30903065
default: break;
3066+
case Intrinsic::vector_reduce_add:
3067+
case Intrinsic::vector_reduce_mul:
3068+
case Intrinsic::vector_reduce_and:
3069+
case Intrinsic::vector_reduce_or:
3070+
case Intrinsic::vector_reduce_xor:
3071+
case Intrinsic::vector_reduce_smin:
3072+
case Intrinsic::vector_reduce_smax:
3073+
case Intrinsic::vector_reduce_umin:
3074+
case Intrinsic::vector_reduce_umax:
3075+
if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0]))
3076+
return C;
3077+
break;
30913078
case Intrinsic::x86_sse_cvtss2si:
30923079
case Intrinsic::x86_sse_cvtss2si64:
30933080
case Intrinsic::x86_sse2_cvtsd2si:
@@ -3116,10 +3103,15 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
31163103
case Intrinsic::wasm_alltrue:
31173104
// Check each element individually
31183105
unsigned E = cast<FixedVectorType>(Op->getType())->getNumElements();
3119-
for (unsigned I = 0; I != E; ++I)
3120-
if (Constant *Elt = Op->getAggregateElement(I))
3121-
if (Elt->isZeroValue())
3122-
return ConstantInt::get(Ty, 0);
3106+
for (unsigned I = 0; I != E; ++I) {
3107+
Constant *Elt = Op->getAggregateElement(I);
3108+
// Return false as soon as we find a non-true element.
3109+
if (Elt && Elt->isZeroValue())
3110+
return ConstantInt::get(Ty, 0);
3111+
// Bail as soon as we find an element we cannot prove to be true.
3112+
if (!Elt || !isa<ConstantInt>(Elt))
3113+
return nullptr;
3114+
}
31233115

31243116
return ConstantInt::get(Ty, 1);
31253117
}

llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/any_all_true.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
22

33
; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
4+
; RUN: opt -passes=instsimplify -use-constant-int-for-fixed-length-splat -S < %s | FileCheck %s
45

56
; Test that intrinsics wasm call are constant folded
67

llvm/test/Transforms/InstSimplify/ConstProp/bitcount.ll

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
3+
; RUN: opt < %s -passes=instsimplify -use-constant-int-for-fixed-length-splat -use-constant-int-for-scalable-splat -S | FileCheck %s
34

45
declare i31 @llvm.ctpop.i31(i31 %val)
56
declare i32 @llvm.cttz.i32(i32 %val, i1)
@@ -120,6 +121,22 @@ define <2 x i31> @ctpop_vector() {
120121
ret <2 x i31> %x
121122
}
122123

124+
define <2 x i31> @ctpop_vector_splat_v2i31() {
125+
; CHECK-LABEL: @ctpop_vector_splat_v2i31(
126+
; CHECK-NEXT: ret <2 x i31> splat (i31 1)
127+
;
128+
%x = call <2 x i31> @llvm.ctpop.v2i31(<2 x i31> splat(i31 16))
129+
ret <2 x i31> %x
130+
}
131+
132+
define <vscale x 2 x i31> @ctpop_vector_splat_nxv2i31() {
133+
; CHECK-LABEL: @ctpop_vector_splat_nxv2i31(
134+
; CHECK-NEXT: ret <vscale x 2 x i31> splat (i31 1)
135+
;
136+
%x = call <vscale x 2 x i31> @llvm.ctpop.nxv2i31(<vscale x 2 x i31> splat(i31 16))
137+
ret <vscale x 2 x i31> %x
138+
}
139+
123140
define <2 x i31> @ctpop_vector_undef() {
124141
; CHECK-LABEL: @ctpop_vector_undef(
125142
; CHECK-NEXT: ret <2 x i31> zeroinitializer
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
2+
; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
3+
; RUN: opt < %s -passes=instsimplify -use-constant-int-for-fixed-length-splat -use-constant-int-for-scalable-splat -S | FileCheck %s
4+
5+
define i16 @W() {
6+
; CHECK-LABEL: define i16 @W() {
7+
; CHECK-NEXT: ret i16 -32768
8+
;
9+
%Z = call i16 @llvm.bitreverse.i16(i16 1)
10+
ret i16 %Z
11+
}
12+
13+
define i32 @X() {
14+
; CHECK-LABEL: define i32 @X() {
15+
; CHECK-NEXT: ret i32 -2147483648
16+
;
17+
%Z = call i32 @llvm.bitreverse.i32(i32 1)
18+
ret i32 %Z
19+
}
20+
21+
define i64 @Y() {
22+
; CHECK-LABEL: define i64 @Y() {
23+
; CHECK-NEXT: ret i64 -9223372036854775808
24+
;
25+
%Z = call i64 @llvm.bitreverse.i64(i64 1)
26+
ret i64 %Z
27+
}
28+
29+
define i80 @Z() {
30+
; CHECK-LABEL: define i80 @Z() {
31+
; CHECK-NEXT: ret i80 23777929115895377691656
32+
;
33+
%Z = call i80 @llvm.bitreverse.i80(i80 76151636403560493650080)
34+
ret i80 %Z
35+
}
36+
37+
define <4 x i32> @bitreverse_splat_v4i32() {
38+
; CHECK-LABEL: define <4 x i32> @bitreverse_splat_v4i32() {
39+
; CHECK-NEXT: ret <4 x i32> splat (i32 -2147483648)
40+
;
41+
%Z = call <4 x i32> @llvm.bitreverse.v4i32(<4 x i32> splat(i32 1))
42+
ret <4 x i32> %Z
43+
}
44+
45+
define <vscale x 4 x i32> @bitreverse_splat_nxv4i32() {
46+
; CHECK-LABEL: define <vscale x 4 x i32> @bitreverse_splat_nxv4i32() {
47+
; CHECK-NEXT: ret <vscale x 4 x i32> splat (i32 -2147483648)
48+
;
49+
%Z = call <vscale x 4 x i32> @llvm.bitreverse.v4i32(<vscale x 4 x i32> splat(i32 1))
50+
ret <vscale x 4 x i32> %Z
51+
}

llvm/test/Transforms/InstSimplify/ConstProp/bswap.ll

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
; bswap should be constant folded when it is passed a constant argument
33

44
; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
5+
; RUN: opt < %s -passes=instsimplify -use-constant-int-for-fixed-length-splat -use-constant-int-for-scalable-splat -S | FileCheck %s
56

67
declare i16 @llvm.bswap.i16(i16)
78

@@ -42,3 +43,19 @@ define i80 @Z() {
4243
%Z = call i80 @llvm.bswap.i80( i80 76151636403560493650080 )
4344
ret i80 %Z
4445
}
46+
47+
define <4 x i32> @bswap_splat_v4i32() {
48+
; CHECK-LABEL: define <4 x i32> @bswap_splat_v4i32() {
49+
; CHECK-NEXT: ret <4 x i32> splat (i32 16777216)
50+
;
51+
%Z = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> splat(i32 1))
52+
ret <4 x i32> %Z
53+
}
54+
55+
define <vscale x 4 x i32> @bswap_splat_nxv4i32() {
56+
; CHECK-LABEL: define <vscale x 4 x i32> @bswap_splat_nxv4i32() {
57+
; CHECK-NEXT: ret <vscale x 4 x i32> splat (i32 16777216)
58+
;
59+
%Z = call <vscale x 4 x i32> @llvm.bswap.v4i32(<vscale x 4 x i32> splat(i32 1))
60+
ret <vscale x 4 x i32> %Z
61+
}

llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
3+
; RUN: opt < %s -passes=instsimplify -use-constant-int-for-fixed-length-splat -S | FileCheck %s
34

45
declare i32 @llvm.vector.reduce.add.v1i32(<1 x i32> %a)
56
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %a)

0 commit comments

Comments
 (0)