Skip to content

Commit 0e6d612

Browse files
authored
[AArch64] Improve select dagcombine (#169925)
An AnyOf reduction (aka vector.reduce.or) with a fixed-width vector is canonicalized to a bitcast of the mask vector to an integer of the same overall size, which is then compared against zero. If the scalar result of the bitcast is smaller than the element size of the vectors being selected, we often end up with suboptimal codegen. This fixes the main cases, removing the scalarized code.
1 parent 153c7e4 commit 0e6d612

File tree

3 files changed

+96
-36
lines changed

3 files changed

+96
-36
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26983,22 +26983,25 @@ static SDValue performSelectCombine(SDNode *N,
2698326983
assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) &&
2698426984
"Scalar-SETCC feeding SELECT has unexpected result type!");
2698526985

26986-
// If NumMaskElts == 0, the comparison is larger than select result. The
26987-
// largest real NEON comparison is 64-bits per lane, which means the result is
26988-
// at most 32-bits and an illegal vector. Just bail out for now.
26989-
EVT SrcVT = N0.getOperand(0).getValueType();
26990-
2699126986
// Don't try to do this optimization when the setcc itself has i1 operands.
2699226987
// There are no legal vectors of i1, so this would be pointless. v1f16 is
2699326988
// ruled out to prevent the creation of setcc that need to be scalarized.
26989+
EVT SrcVT = N0.getOperand(0).getValueType();
2699426990
if (SrcVT == MVT::i1 ||
2699526991
(SrcVT.isFloatingPoint() && SrcVT.getSizeInBits() <= 16))
2699626992
return SDValue();
2699726993

26998-
int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
26994+
// If NumMaskElts == 0, the comparison is larger than select result. The
26995+
// largest real NEON comparison is 64-bits per lane, which means the result is
26996+
// at most 32-bits and an illegal vector. Just bail out for now.
26997+
unsigned NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
2699926998
if (!ResVT.isVector() || NumMaskElts == 0)
2700026999
return SDValue();
2700127000

27001+
// Avoid creating vectors with excessive VFs before legalization.
27002+
if (DCI.isBeforeLegalize() && NumMaskElts != ResVT.getVectorNumElements())
27003+
return SDValue();
27004+
2700227005
SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts);
2700327006
EVT CCVT = SrcVT.changeVectorElementTypeToInteger();
2700427007

llvm/test/CodeGen/AArch64/expand-select.ll

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,15 @@
44
define void @foo(i32 %In1, <2 x i128> %In2, <2 x i128> %In3, ptr %Out) {
55
; CHECK-LABEL: foo:
66
; CHECK: // %bb.0:
7-
; CHECK-NEXT: movi d0, #0000000000000000
8-
; CHECK-NEXT: and w8, w0, #0x1
9-
; CHECK-NEXT: ldr x11, [sp]
10-
; CHECK-NEXT: fmov s1, w8
11-
; CHECK-NEXT: ldp x8, x10, [sp, #8]
12-
; CHECK-NEXT: cmeq v0.4s, v1.4s, v0.4s
13-
; CHECK-NEXT: fmov w9, s0
14-
; CHECK-NEXT: tst w9, #0x1
15-
; CHECK-NEXT: csel x8, x5, x8, ne
16-
; CHECK-NEXT: csel x9, x4, x11, ne
17-
; CHECK-NEXT: stp x9, x8, [x10, #16]
18-
; CHECK-NEXT: csel x8, x3, x7, ne
19-
; CHECK-NEXT: csel x9, x2, x6, ne
20-
; CHECK-NEXT: stp x9, x8, [x10]
7+
; CHECK-NEXT: ldp x8, x9, [sp, #8]
8+
; CHECK-NEXT: tst w0, #0x1
9+
; CHECK-NEXT: ldr x10, [sp]
10+
; CHECK-NEXT: csel x8, x5, x8, eq
11+
; CHECK-NEXT: csel x10, x4, x10, eq
12+
; CHECK-NEXT: stp x10, x8, [x9, #16]
13+
; CHECK-NEXT: csel x8, x3, x7, eq
14+
; CHECK-NEXT: csel x10, x2, x6, eq
15+
; CHECK-NEXT: stp x10, x8, [x9]
2116
; CHECK-NEXT: ret
2217
%cond = and i32 %In1, 1
2318
%cbool = icmp eq i32 %cond, 0
@@ -31,22 +26,17 @@ define void @foo(i32 %In1, <2 x i128> %In2, <2 x i128> %In3, ptr %Out) {
3126
define void @bar(i32 %In1, <2 x i96> %In2, <2 x i96> %In3, ptr %Out) {
3227
; CHECK-LABEL: bar:
3328
; CHECK: // %bb.0:
34-
; CHECK-NEXT: movi d0, #0000000000000000
35-
; CHECK-NEXT: and w8, w0, #0x1
36-
; CHECK-NEXT: ldr x10, [sp, #16]
37-
; CHECK-NEXT: fmov s1, w8
38-
; CHECK-NEXT: cmeq v0.4s, v1.4s, v0.4s
39-
; CHECK-NEXT: fmov w9, s0
40-
; CHECK-NEXT: tst w9, #0x1
41-
; CHECK-NEXT: ldp x8, x9, [sp]
42-
; CHECK-NEXT: csel x11, x2, x6, ne
43-
; CHECK-NEXT: str x11, [x10]
44-
; CHECK-NEXT: csel x8, x4, x8, ne
45-
; CHECK-NEXT: stur x8, [x10, #12]
46-
; CHECK-NEXT: csel x8, x5, x9, ne
47-
; CHECK-NEXT: csel x9, x3, x7, ne
48-
; CHECK-NEXT: str w8, [x10, #20]
49-
; CHECK-NEXT: str w9, [x10, #8]
29+
; CHECK-NEXT: ldp x8, x10, [sp]
30+
; CHECK-NEXT: tst w0, #0x1
31+
; CHECK-NEXT: ldr x9, [sp, #16]
32+
; CHECK-NEXT: csel x11, x2, x6, eq
33+
; CHECK-NEXT: csel x8, x4, x8, eq
34+
; CHECK-NEXT: str x11, [x9]
35+
; CHECK-NEXT: stur x8, [x9, #12]
36+
; CHECK-NEXT: csel x8, x5, x10, eq
37+
; CHECK-NEXT: csel x10, x3, x7, eq
38+
; CHECK-NEXT: str w8, [x9, #20]
39+
; CHECK-NEXT: str w10, [x9, #8]
5040
; CHECK-NEXT: ret
5141
%cond = and i32 %In1, 1
5242
%cbool = icmp eq i32 %cond, 0
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc %s -o - | FileCheck %s
3+
target triple = "aarch64-linux-gnu"
4+
5+
;; An 'AnyOf' reduction (vector.reduce.or) is instcombined to a bitcast to an
6+
;; integer of a bitwidth equal to the number of lanes being reduced, then
7+
;; compared against zero. To select between vectors for NEON, we then need to
8+
;; broadcast the result, but we must be careful when the bitwidth of the scalar
9+
;; result is smaller than the element size of the vectors being selected. We
10+
;; don't want to end up with scalarization.
11+
12+
define <4 x i32> @any_of_select_vf4(<4 x i32> %mask, <4 x i32> %a, <4 x i32> %b) {
13+
; CHECK-LABEL: any_of_select_vf4:
14+
; CHECK: // %bb.0:
15+
; CHECK-NEXT: cmlt v0.4s, v0.4s, #0
16+
; CHECK-NEXT: umaxv s0, v0.4s
17+
; CHECK-NEXT: fmov w8, s0
18+
; CHECK-NEXT: tst w8, #0x1
19+
; CHECK-NEXT: csetm w8, ne
20+
; CHECK-NEXT: dup v0.4s, w8
21+
; CHECK-NEXT: bsl v0.16b, v2.16b, v1.16b
22+
; CHECK-NEXT: ret
23+
%cmp = icmp slt <4 x i32> %mask, zeroinitializer
24+
%cmp.bc = bitcast <4 x i1> %cmp to i4
25+
%cmp.bc.not = icmp eq i4 %cmp.bc, 0
26+
%res = select i1 %cmp.bc.not, <4 x i32> %a, <4 x i32> %b
27+
ret <4 x i32> %res
28+
}
29+
30+
define <2 x i64> @any_of_select_vf2(<2 x i64> %mask, <2 x i64> %a, <2 x i64> %b) {
31+
; CHECK-LABEL: any_of_select_vf2:
32+
; CHECK: // %bb.0:
33+
; CHECK-NEXT: cmlt v0.2d, v0.2d, #0
34+
; CHECK-NEXT: umaxv s0, v0.4s
35+
; CHECK-NEXT: fmov w8, s0
36+
; CHECK-NEXT: tst w8, #0x1
37+
; CHECK-NEXT: csetm x8, ne
38+
; CHECK-NEXT: dup v0.2d, x8
39+
; CHECK-NEXT: bsl v0.16b, v2.16b, v1.16b
40+
; CHECK-NEXT: ret
41+
%cmp = icmp slt <2 x i64> %mask, zeroinitializer
42+
%cmp.bc = bitcast <2 x i1> %cmp to i2
43+
%cmp.bc.not = icmp eq i2 %cmp.bc, 0
44+
%res = select i1 %cmp.bc.not, <2 x i64> %a, <2 x i64> %b
45+
ret <2 x i64> %res
46+
}
47+
48+
define <32 x i8> @any_of_select_vf32(<32 x i8> %mask, <32 x i8> %a, <32 x i8> %b) {
49+
; CHECK-LABEL: any_of_select_vf32:
50+
; CHECK: // %bb.0:
51+
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
52+
; CHECK-NEXT: cmlt v0.16b, v0.16b, #0
53+
; CHECK-NEXT: umaxv b0, v0.16b
54+
; CHECK-NEXT: fmov w8, s0
55+
; CHECK-NEXT: tst w8, #0x1
56+
; CHECK-NEXT: csetm w8, ne
57+
; CHECK-NEXT: dup v1.16b, w8
58+
; CHECK-NEXT: mov v0.16b, v1.16b
59+
; CHECK-NEXT: bsl v1.16b, v5.16b, v3.16b
60+
; CHECK-NEXT: bsl v0.16b, v4.16b, v2.16b
61+
; CHECK-NEXT: ret
62+
%cmp = icmp slt <32 x i8> %mask, zeroinitializer
63+
%cmp.bc = bitcast <32 x i1> %cmp to i32
64+
%cmp.bc.not = icmp eq i32 %cmp.bc, 0
65+
%res = select i1 %cmp.bc.not, <32 x i8> %a, <32 x i8> %b
66+
ret <32 x i8> %res
67+
}

0 commit comments

Comments
 (0)