Skip to content

Commit fd330de

Browse files
authored
[DAG] Constant fold ISD::FSHL/FSHR nodes (llvm#154480)
Fixes llvm#153612. This patch constant-folds FSHL/FSHR nodes whose three operands are all scalar integer constants (the ternary case) in `FoldConstantArithmetic`. Pending until llvm#153790 is merged.
1 parent c9e5b6a commit fd330de

File tree

3 files changed

+206
-27
lines changed

3 files changed

+206
-27
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11281,6 +11281,11 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
1128111281
unsigned BitWidth = VT.getScalarSizeInBits();
1128211282
SDLoc DL(N);
1128311283

11284+
// fold (fshl/fshr C0, C1, C2) -> C3
11285+
if (SDValue C =
11286+
DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1, N2}))
11287+
return C;
11288+
1128411289
// fold (fshl N0, N1, 0) -> N0
1128511290
// fold (fshr N0, N1, 0) -> N1
1128611291
if (isPowerOf2_32(BitWidth))

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7175,6 +7175,45 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
71757175
}
71767176
}
71777177

7178+
// Constant fold fshl/fshr when all three operands are non-opaque integer constants.
7179+
if (Opcode == ISD::FSHL || Opcode == ISD::FSHR) {
7180+
auto *C1 = dyn_cast<ConstantSDNode>(Ops[0]);
7181+
auto *C2 = dyn_cast<ConstantSDNode>(Ops[1]);
7182+
auto *C3 = dyn_cast<ConstantSDNode>(Ops[2]);
7183+
7184+
if (C1 && C2 && C3) {
7185+
if (C1->isOpaque() || C2->isOpaque() || C3->isOpaque())
7186+
return SDValue();
7187+
const APInt &V1 = C1->getAPIntValue(), &V2 = C2->getAPIntValue(),
7188+
&V3 = C3->getAPIntValue();
7189+
7190+
APInt FoldedVal = Opcode == ISD::FSHL ? APIntOps::fshl(V1, V2, V3)
7191+
: APIntOps::fshr(V1, V2, V3);
7192+
return getConstant(FoldedVal, DL, VT);
7193+
}
7194+
}
7195+
7196+
// Handle fma/fmad special cases.
7197+
if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
7198+
assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
7199+
assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
7200+
Ops[2].getValueType() == VT && "FMA types must match!");
7201+
ConstantFPSDNode *C1 = dyn_cast<ConstantFPSDNode>(Ops[0]);
7202+
ConstantFPSDNode *C2 = dyn_cast<ConstantFPSDNode>(Ops[1]);
7203+
ConstantFPSDNode *C3 = dyn_cast<ConstantFPSDNode>(Ops[2]);
7204+
if (C1 && C2 && C3) {
7205+
APFloat V1 = C1->getValueAPF();
7206+
const APFloat &V2 = C2->getValueAPF();
7207+
const APFloat &V3 = C3->getValueAPF();
7208+
if (Opcode == ISD::FMAD) {
7209+
V1.multiply(V2, APFloat::rmNearestTiesToEven);
7210+
V1.add(V3, APFloat::rmNearestTiesToEven);
7211+
} else
7212+
V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
7213+
return getConstantFP(V1, DL, VT);
7214+
}
7215+
}
7216+
71787217
// This is for vector folding only from here on.
71797218
if (!VT.isVector())
71807219
return SDValue();
@@ -8137,27 +8176,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
81378176
"Operand is DELETED_NODE!");
81388177
// Perform various simplifications.
81398178
switch (Opcode) {
8140-
case ISD::FMA:
8141-
case ISD::FMAD: {
8142-
assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
8143-
assert(N1.getValueType() == VT && N2.getValueType() == VT &&
8144-
N3.getValueType() == VT && "FMA types must match!");
8145-
ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
8146-
ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
8147-
ConstantFPSDNode *N3CFP = dyn_cast<ConstantFPSDNode>(N3);
8148-
if (N1CFP && N2CFP && N3CFP) {
8149-
APFloat V1 = N1CFP->getValueAPF();
8150-
const APFloat &V2 = N2CFP->getValueAPF();
8151-
const APFloat &V3 = N3CFP->getValueAPF();
8152-
if (Opcode == ISD::FMAD) {
8153-
V1.multiply(V2, APFloat::rmNearestTiesToEven);
8154-
V1.add(V3, APFloat::rmNearestTiesToEven);
8155-
} else
8156-
V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
8157-
return getConstantFP(V1, DL, VT);
8158-
}
8159-
break;
8160-
}
81618179
case ISD::BUILD_VECTOR: {
81628180
// Attempt to simplify BUILD_VECTOR.
81638181
SDValue Ops[] = {N1, N2, N3};
@@ -8183,12 +8201,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
81838201
// Use FoldSetCC to simplify SETCC's.
81848202
if (SDValue V = FoldSetCC(VT, N1, N2, cast<CondCodeSDNode>(N3)->get(), DL))
81858203
return V;
8186-
// Vector constant folding.
8187-
SDValue Ops[] = {N1, N2, N3};
8188-
if (SDValue V = FoldConstantArithmetic(Opcode, DL, VT, Ops)) {
8189-
NewSDValueDbgMsg(V, "New node vector constant folding: ", this);
8190-
return V;
8191-
}
81928204
break;
81938205
}
81948206
case ISD::SELECT:
@@ -8324,6 +8336,19 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
83248336
}
83258337
}
83268338

8339+
// Perform trivial constant folding for arithmetic operators.
8340+
switch (Opcode) {
8341+
case ISD::FMA:
8342+
case ISD::FMAD:
8343+
case ISD::SETCC:
8344+
case ISD::FSHL:
8345+
case ISD::FSHR:
8346+
if (SDValue SV =
8347+
FoldConstantArithmetic(Opcode, DL, VT, {N1, N2, N3}, Flags))
8348+
return SV;
8349+
break;
8350+
}
8351+
83278352
// Memoize node if it doesn't produce a glue result.
83288353
SDNode *N;
83298354
SDVTList VTs = getVTList(VT);
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefixes=CHECK,CHECK-EXPAND
3+
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vl,+avx512vbmi2 | FileCheck %s --check-prefixes=CHECK,CHECK-UNEXPAND
4+
5+
define <4 x i32> @test_fshl_constants() {
6+
; CHECK-EXPAND-LABEL: test_fshl_constants:
7+
; CHECK-EXPAND: # %bb.0:
8+
; CHECK-EXPAND-NEXT: vmovaps {{.*#+}} xmm0 = [0,512,2048,6144]
9+
; CHECK-EXPAND-NEXT: retq
10+
;
11+
; CHECK-UNEXPAND-LABEL: test_fshl_constants:
12+
; CHECK-UNEXPAND: # %bb.0:
13+
; CHECK-UNEXPAND-NEXT: vpmovsxwd {{.*#+}} xmm0 = [0,512,2048,6144]
14+
; CHECK-UNEXPAND-NEXT: retq
15+
%res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> <i32 0, i32 1, i32 2, i32 3>, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
16+
ret <4 x i32> %res
17+
}
18+
19+
define <4 x i32> @test_fshl_splat_constants() {
20+
; CHECK-LABEL: test_fshl_splat_constants:
21+
; CHECK: # %bb.0:
22+
; CHECK-NEXT: vbroadcastss {{.*#+}} xmm0 = [256,256,256,256]
23+
; CHECK-NEXT: retq
24+
%res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 4, i32 4, i32 4, i32 4>, <4 x i32> <i32 8, i32 8, i32 8, i32 8>)
25+
ret <4 x i32> %res
26+
}
27+
28+
define <4 x i32> @test_fshl_two_constants(<4 x i32> %a) {
29+
; CHECK-EXPAND-LABEL: test_fshl_two_constants:
30+
; CHECK-EXPAND: # %bb.0:
31+
; CHECK-EXPAND-NEXT: vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
32+
; CHECK-EXPAND-NEXT: retq
33+
;
34+
; CHECK-UNEXPAND-LABEL: test_fshl_two_constants:
35+
; CHECK-UNEXPAND: # %bb.0:
36+
; CHECK-UNEXPAND-NEXT: vpmovsxbd {{.*#+}} xmm1 = [4,5,6,7]
37+
; CHECK-UNEXPAND-NEXT: vpshldvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
38+
; CHECK-UNEXPAND-NEXT: retq
39+
%res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %a, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
40+
ret <4 x i32> %res
41+
}
42+
43+
define <4 x i32> @test_fshl_one_constant(<4 x i32> %a, <4 x i32> %b) {
44+
; CHECK-EXPAND-LABEL: test_fshl_one_constant:
45+
; CHECK-EXPAND: # %bb.0:
46+
; CHECK-EXPAND-NEXT: vpsrlvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
47+
; CHECK-EXPAND-NEXT: vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
48+
; CHECK-EXPAND-NEXT: vpor %xmm1, %xmm0, %xmm0
49+
; CHECK-EXPAND-NEXT: retq
50+
;
51+
; CHECK-UNEXPAND-LABEL: test_fshl_one_constant:
52+
; CHECK-UNEXPAND: # %bb.0:
53+
; CHECK-UNEXPAND-NEXT: vpshldvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
54+
; CHECK-UNEXPAND-NEXT: retq
55+
%res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
56+
ret <4 x i32> %res
57+
}
58+
59+
define <4 x i32> @test_fshl_none_constant(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
60+
; CHECK-EXPAND-LABEL: test_fshl_none_constant:
61+
; CHECK-EXPAND: # %bb.0:
62+
; CHECK-EXPAND-NEXT: vpbroadcastd {{.*#+}} xmm3 = [31,31,31,31]
63+
; CHECK-EXPAND-NEXT: vpandn %xmm3, %xmm2, %xmm4
64+
; CHECK-EXPAND-NEXT: vpsrld $1, %xmm1, %xmm1
65+
; CHECK-EXPAND-NEXT: vpsrlvd %xmm4, %xmm1, %xmm1
66+
; CHECK-EXPAND-NEXT: vpand %xmm3, %xmm2, %xmm2
67+
; CHECK-EXPAND-NEXT: vpsllvd %xmm2, %xmm0, %xmm0
68+
; CHECK-EXPAND-NEXT: vpor %xmm1, %xmm0, %xmm0
69+
; CHECK-EXPAND-NEXT: retq
70+
;
71+
; CHECK-UNEXPAND-LABEL: test_fshl_none_constant:
72+
; CHECK-UNEXPAND: # %bb.0:
73+
; CHECK-UNEXPAND-NEXT: vpshldvd %xmm2, %xmm1, %xmm0
74+
; CHECK-UNEXPAND-NEXT: retq
75+
%res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
76+
ret <4 x i32> %res
77+
}
78+
79+
define <4 x i32> @test_fshr_constants() {
80+
; CHECK-LABEL: test_fshr_constants:
81+
; CHECK: # %bb.0:
82+
; CHECK-NEXT: vmovaps {{.*#+}} xmm0 = [0,8388608,8388608,6291456]
83+
; CHECK-NEXT: retq
84+
%res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> <i32 0, i32 1, i32 2, i32 3>, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
85+
ret <4 x i32> %res
86+
}
87+
88+
define <4 x i32> @test_fshr_two_constants(<4 x i32> %a) {
89+
; CHECK-EXPAND-LABEL: test_fshr_two_constants:
90+
; CHECK-EXPAND: # %bb.0:
91+
; CHECK-EXPAND-NEXT: vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
92+
; CHECK-EXPAND-NEXT: retq
93+
;
94+
; CHECK-UNEXPAND-LABEL: test_fshr_two_constants:
95+
; CHECK-UNEXPAND: # %bb.0:
96+
; CHECK-UNEXPAND-NEXT: vpmovsxbd {{.*#+}} xmm1 = [4,5,6,7]
97+
; CHECK-UNEXPAND-NEXT: vpshrdvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
98+
; CHECK-UNEXPAND-NEXT: vmovdqa %xmm1, %xmm0
99+
; CHECK-UNEXPAND-NEXT: retq
100+
%res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %a, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
101+
ret <4 x i32> %res
102+
}
103+
104+
define <4 x i32> @test_fshr_one_constant(<4 x i32> %a, <4 x i32> %b) {
105+
; CHECK-EXPAND-LABEL: test_fshr_one_constant:
106+
; CHECK-EXPAND: # %bb.0:
107+
; CHECK-EXPAND-NEXT: vpsrlvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
108+
; CHECK-EXPAND-NEXT: vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
109+
; CHECK-EXPAND-NEXT: vpor %xmm1, %xmm0, %xmm0
110+
; CHECK-EXPAND-NEXT: retq
111+
;
112+
; CHECK-UNEXPAND-LABEL: test_fshr_one_constant:
113+
; CHECK-UNEXPAND: # %bb.0:
114+
; CHECK-UNEXPAND-NEXT: vpshrdvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
115+
; CHECK-UNEXPAND-NEXT: vmovdqa %xmm1, %xmm0
116+
; CHECK-UNEXPAND-NEXT: retq
117+
%res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
118+
ret <4 x i32> %res
119+
}
120+
121+
define <4 x i32> @test_fshr_none_constant(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
122+
; CHECK-EXPAND-LABEL: test_fshr_none_constant:
123+
; CHECK-EXPAND: # %bb.0:
124+
; CHECK-EXPAND-NEXT: vpbroadcastd {{.*#+}} xmm3 = [31,31,31,31]
125+
; CHECK-EXPAND-NEXT: vpand %xmm3, %xmm2, %xmm4
126+
; CHECK-EXPAND-NEXT: vpsrlvd %xmm4, %xmm1, %xmm1
127+
; CHECK-EXPAND-NEXT: vpandn %xmm3, %xmm2, %xmm2
128+
; CHECK-EXPAND-NEXT: vpaddd %xmm0, %xmm0, %xmm0
129+
; CHECK-EXPAND-NEXT: vpsllvd %xmm2, %xmm0, %xmm0
130+
; CHECK-EXPAND-NEXT: vpor %xmm1, %xmm0, %xmm0
131+
; CHECK-EXPAND-NEXT: retq
132+
;
133+
; CHECK-UNEXPAND-LABEL: test_fshr_none_constant:
134+
; CHECK-UNEXPAND: # %bb.0:
135+
; CHECK-UNEXPAND-NEXT: vpshrdvd %xmm2, %xmm0, %xmm1
136+
; CHECK-UNEXPAND-NEXT: vmovdqa %xmm1, %xmm0
137+
; CHECK-UNEXPAND-NEXT: retq
138+
%res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
139+
ret <4 x i32> %res
140+
}
141+
142+
define <4 x i32> @test_fshr_splat_constants() {
143+
; CHECK-LABEL: test_fshr_splat_constants:
144+
; CHECK: # %bb.0:
145+
; CHECK-NEXT: vbroadcastss {{.*#+}} xmm0 = [16777216,16777216,16777216,16777216]
146+
; CHECK-NEXT: retq
147+
%res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 4, i32 4, i32 4, i32 4>, <4 x i32> <i32 8, i32 8, i32 8, i32 8>)
148+
ret <4 x i32> %res
149+
}

0 commit comments

Comments
 (0)