Skip to content

Commit c25293c

Browse files
authored
[LegalizeVectorOps][RISCV] Don't promote VP_FABS/FNEG/FCOPYSIGN. (llvm#106659)
Promoting canonicalizes NaNs, which changes the semantics. Bitcast to integer and use logic ops instead.
1 parent 688843b commit c25293c

File tree

7 files changed

+336
-484
lines changed

7 files changed

+336
-484
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ class VectorLegalizer {
135135
SDValue ExpandVP_SELECT(SDNode *Node);
136136
SDValue ExpandVP_MERGE(SDNode *Node);
137137
SDValue ExpandVP_REM(SDNode *Node);
138+
SDValue ExpandVP_FNEG(SDNode *Node);
139+
SDValue ExpandVP_FABS(SDNode *Node);
140+
SDValue ExpandVP_FCOPYSIGN(SDNode *Node);
138141
SDValue ExpandSELECT(SDNode *Node);
139142
std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
140143
SDValue ExpandStore(SDNode *N);
@@ -699,6 +702,11 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
699702
// These operations are used to do promotion so they can't be promoted
700703
// themselves.
701704
llvm_unreachable("Don't know how to promote this operation!");
705+
case ISD::VP_FABS:
706+
case ISD::VP_FCOPYSIGN:
707+
case ISD::VP_FNEG:
708+
// Promoting fabs, fneg, and fcopysign changes their semantics.
709+
llvm_unreachable("These operations should not be promoted");
702710
}
703711

704712
// There are currently two cases of vector promotion:
@@ -887,6 +895,24 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
887895
return;
888896
}
889897
break;
898+
case ISD::VP_FNEG:
899+
if (SDValue Expanded = ExpandVP_FNEG(Node)) {
900+
Results.push_back(Expanded);
901+
return;
902+
}
903+
break;
904+
case ISD::VP_FABS:
905+
if (SDValue Expanded = ExpandVP_FABS(Node)) {
906+
Results.push_back(Expanded);
907+
return;
908+
}
909+
break;
910+
case ISD::VP_FCOPYSIGN:
911+
if (SDValue Expanded = ExpandVP_FCOPYSIGN(Node)) {
912+
Results.push_back(Expanded);
913+
return;
914+
}
915+
break;
890916
case ISD::SELECT:
891917
Results.push_back(ExpandSELECT(Node));
892918
return;
@@ -1557,6 +1583,80 @@ SDValue VectorLegalizer::ExpandVP_REM(SDNode *Node) {
15571583
return DAG.getNode(ISD::VP_SUB, DL, VT, Dividend, Mul, Mask, EVL);
15581584
}
15591585

1586+
/// Expand VP_FNEG with integer logic instead of an FP operation: the source
/// vector is bitcast to the same-width integer vector type, the sign bit of
/// every element is flipped with VP_XOR, and the result is bitcast back.
///
/// Returns an empty SDValue if VP_XOR is neither legal nor custom for the
/// integer vector type.
SDValue VectorLegalizer::ExpandVP_FNEG(SDNode *Node) {
  const EVT FloatVT = Node->getValueType(0);
  const EVT IntVT = FloatVT.changeVectorElementTypeToInteger();

  // Nothing to do unless the target can handle the predicated integer XOR.
  const bool CanXor = TLI.isOperationLegalOrCustom(ISD::VP_XOR, IntVT);
  if (!CanXor)
    return SDValue();

  // Operands 1 and 2 are forwarded unchanged to the VP_XOR.
  SDValue Mask = Node->getOperand(1);
  SDValue EVL = Node->getOperand(2);

  SDLoc DL(Node);

  // Reinterpret the FP input as raw integer bits.
  SDValue SrcBits = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(0));

  // Constant with only the top bit of each element set.
  const APInt SignBit = APInt::getSignMask(IntVT.getScalarSizeInBits());
  SDValue SignMask = DAG.getConstant(SignBit, DL, IntVT);

  // XOR with the sign mask toggles the sign and leaves all other bits alone.
  SDValue Flipped =
      DAG.getNode(ISD::VP_XOR, DL, IntVT, SrcBits, SignMask, Mask, EVL);
  return DAG.getNode(ISD::BITCAST, DL, FloatVT, Flipped);
}
1603+
1604+
/// Expand VP_FABS with integer logic instead of an FP operation: the source
/// vector is bitcast to the same-width integer vector type, the sign bit of
/// every element is cleared with VP_AND, and the result is bitcast back.
///
/// Returns an empty SDValue if VP_AND is neither legal nor custom for the
/// integer vector type.
SDValue VectorLegalizer::ExpandVP_FABS(SDNode *Node) {
  const EVT FloatVT = Node->getValueType(0);
  const EVT IntVT = FloatVT.changeVectorElementTypeToInteger();

  // Nothing to do unless the target can handle the predicated integer AND.
  const bool CanAnd = TLI.isOperationLegalOrCustom(ISD::VP_AND, IntVT);
  if (!CanAnd)
    return SDValue();

  // Operands 1 and 2 are forwarded unchanged to the VP_AND.
  SDValue Mask = Node->getOperand(1);
  SDValue EVL = Node->getOperand(2);

  SDLoc DL(Node);

  // Reinterpret the FP input as raw integer bits.
  SDValue SrcBits = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(0));

  // Signed-max has every bit set except the top one, so ANDing with it clears
  // just the sign bit of each element.
  const APInt AllButSign =
      APInt::getSignedMaxValue(IntVT.getScalarSizeInBits());
  SDValue ClearSignMask = DAG.getConstant(AllButSign, DL, IntVT);

  SDValue Magnitude =
      DAG.getNode(ISD::VP_AND, DL, IntVT, SrcBits, ClearSignMask, Mask, EVL);
  return DAG.getNode(ISD::BITCAST, DL, FloatVT, Magnitude);
}
1622+
1623+
/// Expand VP_FCOPYSIGN with integer logic instead of an FP operation. Both
/// inputs are bitcast to the same-width integer vector type; the sign bit is
/// isolated from the sign operand, cleared from the magnitude operand, and the
/// two pieces are merged with VP_OR before bitcasting back.
///
/// Returns an empty SDValue when the magnitude and sign operands have
/// different types, or when VP_AND or VP_OR is neither legal nor custom for
/// the integer vector type.
SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) {
  EVT VT = Node->getValueType(0);

  // The bitwise expansion requires the sign operand to have the result type.
  if (VT != Node->getOperand(1).getValueType())
    return SDValue();

  // Check exactly the opcodes this expansion emits: two VP_ANDs and one VP_OR.
  // (VP_XOR is not used here, so its legality is irrelevant.)
  EVT IntVT = VT.changeVectorElementTypeToInteger();
  if (!TLI.isOperationLegalOrCustom(ISD::VP_AND, IntVT) ||
      !TLI.isOperationLegalOrCustom(ISD::VP_OR, IntVT))
    return SDValue();

  // Operands 2 and 3 are forwarded unchanged to every VP node created below.
  SDValue Mask = Node->getOperand(2);
  SDValue EVL = Node->getOperand(3);

  SDLoc DL(Node);
  SDValue Mag = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(0));
  SDValue Sign = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(1));

  // Keep only the sign bit of the sign operand.
  SDValue SignMask = DAG.getConstant(
      APInt::getSignMask(IntVT.getScalarSizeInBits()), DL, IntVT);
  SDValue SignBit =
      DAG.getNode(ISD::VP_AND, DL, IntVT, Sign, SignMask, Mask, EVL);

  // Keep everything except the sign bit of the magnitude operand.
  SDValue ClearSignMask = DAG.getConstant(
      APInt::getSignedMaxValue(IntVT.getScalarSizeInBits()), DL, IntVT);
  SDValue ClearedSign =
      DAG.getNode(ISD::VP_AND, DL, IntVT, Mag, ClearSignMask, Mask, EVL);

  // The two AND results were produced with complementary masks, so they share
  // no set bits; record that on the OR via the disjoint flag.
  SDNodeFlags Flags;
  Flags.setDisjoint(true);

  SDValue CopiedSign = DAG.getNode(ISD::VP_OR, DL, IntVT, ClearedSign, SignBit,
                                   Mask, EVL, Flags);

  return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
}
1659+
15601660
void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node,
15611661
SmallVectorImpl<SDValue> &Results) {
15621662
// Attempt to expand using TargetLowering.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -891,16 +891,30 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
891891
ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA};
892892

893893
// TODO: support more vp ops.
894-
static const unsigned ZvfhminPromoteVPOps[] = {
895-
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
896-
ISD::VP_FDIV, ISD::VP_FNEG, ISD::VP_FABS,
897-
ISD::VP_FMA, ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
898-
ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_SQRT,
899-
ISD::VP_FMINNUM, ISD::VP_FMAXNUM, ISD::VP_FCEIL,
900-
ISD::VP_FFLOOR, ISD::VP_FROUND, ISD::VP_FROUNDEVEN,
901-
ISD::VP_FCOPYSIGN, ISD::VP_FROUNDTOZERO, ISD::VP_FRINT,
902-
ISD::VP_FNEARBYINT, ISD::VP_SETCC, ISD::VP_FMINIMUM,
903-
ISD::VP_FMAXIMUM, ISD::VP_REDUCE_FMINIMUM, ISD::VP_REDUCE_FMAXIMUM};
894+
static const unsigned ZvfhminPromoteVPOps[] = {ISD::VP_FADD,
895+
ISD::VP_FSUB,
896+
ISD::VP_FMUL,
897+
ISD::VP_FDIV,
898+
ISD::VP_FMA,
899+
ISD::VP_REDUCE_FADD,
900+
ISD::VP_REDUCE_SEQ_FADD,
901+
ISD::VP_REDUCE_FMIN,
902+
ISD::VP_REDUCE_FMAX,
903+
ISD::VP_SQRT,
904+
ISD::VP_FMINNUM,
905+
ISD::VP_FMAXNUM,
906+
ISD::VP_FCEIL,
907+
ISD::VP_FFLOOR,
908+
ISD::VP_FROUND,
909+
ISD::VP_FROUNDEVEN,
910+
ISD::VP_FROUNDTOZERO,
911+
ISD::VP_FRINT,
912+
ISD::VP_FNEARBYINT,
913+
ISD::VP_SETCC,
914+
ISD::VP_FMINIMUM,
915+
ISD::VP_FMAXIMUM,
916+
ISD::VP_REDUCE_FMINIMUM,
917+
ISD::VP_REDUCE_FMAXIMUM};
904918

905919
// Sets common operation actions on RVV floating-point vector types.
906920
const auto SetCommonVFPActions = [&](MVT VT) {

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfabs-vp.ll

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@ define <2 x half> @vfabs_vv_v2f16(<2 x half> %va, <2 x i1> %m, i32 zeroext %evl)
1919
;
2020
; ZVFHMIN-LABEL: vfabs_vv_v2f16:
2121
; ZVFHMIN: # %bb.0:
22-
; ZVFHMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
23-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v8
24-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, mf2, ta, ma
25-
; ZVFHMIN-NEXT: vfabs.v v9, v9, v0.t
26-
; ZVFHMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
27-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v9
22+
; ZVFHMIN-NEXT: lui a1, 8
23+
; ZVFHMIN-NEXT: addi a1, a1, -1
24+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
25+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1, v0.t
2826
; ZVFHMIN-NEXT: ret
2927
%v = call <2 x half> @llvm.vp.fabs.v2f16(<2 x half> %va, <2 x i1> %m, i32 %evl)
3028
ret <2 x half> %v
@@ -39,12 +37,10 @@ define <2 x half> @vfabs_vv_v2f16_unmasked(<2 x half> %va, i32 zeroext %evl) {
3937
;
4038
; ZVFHMIN-LABEL: vfabs_vv_v2f16_unmasked:
4139
; ZVFHMIN: # %bb.0:
42-
; ZVFHMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
43-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v8
44-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, mf2, ta, ma
45-
; ZVFHMIN-NEXT: vfabs.v v9, v9
46-
; ZVFHMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
47-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v9
40+
; ZVFHMIN-NEXT: lui a1, 8
41+
; ZVFHMIN-NEXT: addi a1, a1, -1
42+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
43+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1
4844
; ZVFHMIN-NEXT: ret
4945
%v = call <2 x half> @llvm.vp.fabs.v2f16(<2 x half> %va, <2 x i1> splat (i1 true), i32 %evl)
5046
ret <2 x half> %v
@@ -61,12 +57,10 @@ define <4 x half> @vfabs_vv_v4f16(<4 x half> %va, <4 x i1> %m, i32 zeroext %evl)
6157
;
6258
; ZVFHMIN-LABEL: vfabs_vv_v4f16:
6359
; ZVFHMIN: # %bb.0:
64-
; ZVFHMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
65-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v8
66-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, m1, ta, ma
67-
; ZVFHMIN-NEXT: vfabs.v v9, v9, v0.t
68-
; ZVFHMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
69-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v9
60+
; ZVFHMIN-NEXT: lui a1, 8
61+
; ZVFHMIN-NEXT: addi a1, a1, -1
62+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
63+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1, v0.t
7064
; ZVFHMIN-NEXT: ret
7165
%v = call <4 x half> @llvm.vp.fabs.v4f16(<4 x half> %va, <4 x i1> %m, i32 %evl)
7266
ret <4 x half> %v
@@ -81,12 +75,10 @@ define <4 x half> @vfabs_vv_v4f16_unmasked(<4 x half> %va, i32 zeroext %evl) {
8175
;
8276
; ZVFHMIN-LABEL: vfabs_vv_v4f16_unmasked:
8377
; ZVFHMIN: # %bb.0:
84-
; ZVFHMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
85-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v8
86-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, m1, ta, ma
87-
; ZVFHMIN-NEXT: vfabs.v v9, v9
88-
; ZVFHMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
89-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v9
78+
; ZVFHMIN-NEXT: lui a1, 8
79+
; ZVFHMIN-NEXT: addi a1, a1, -1
80+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
81+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1
9082
; ZVFHMIN-NEXT: ret
9183
%v = call <4 x half> @llvm.vp.fabs.v4f16(<4 x half> %va, <4 x i1> splat (i1 true), i32 %evl)
9284
ret <4 x half> %v
@@ -103,12 +95,10 @@ define <8 x half> @vfabs_vv_v8f16(<8 x half> %va, <8 x i1> %m, i32 zeroext %evl)
10395
;
10496
; ZVFHMIN-LABEL: vfabs_vv_v8f16:
10597
; ZVFHMIN: # %bb.0:
106-
; ZVFHMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
107-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v10, v8
108-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, m2, ta, ma
109-
; ZVFHMIN-NEXT: vfabs.v v10, v10, v0.t
110-
; ZVFHMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
111-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v10
98+
; ZVFHMIN-NEXT: lui a1, 8
99+
; ZVFHMIN-NEXT: addi a1, a1, -1
100+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, m1, ta, ma
101+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1, v0.t
112102
; ZVFHMIN-NEXT: ret
113103
%v = call <8 x half> @llvm.vp.fabs.v8f16(<8 x half> %va, <8 x i1> %m, i32 %evl)
114104
ret <8 x half> %v
@@ -123,12 +113,10 @@ define <8 x half> @vfabs_vv_v8f16_unmasked(<8 x half> %va, i32 zeroext %evl) {
123113
;
124114
; ZVFHMIN-LABEL: vfabs_vv_v8f16_unmasked:
125115
; ZVFHMIN: # %bb.0:
126-
; ZVFHMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
127-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v10, v8
128-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, m2, ta, ma
129-
; ZVFHMIN-NEXT: vfabs.v v10, v10
130-
; ZVFHMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
131-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v10
116+
; ZVFHMIN-NEXT: lui a1, 8
117+
; ZVFHMIN-NEXT: addi a1, a1, -1
118+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, m1, ta, ma
119+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1
132120
; ZVFHMIN-NEXT: ret
133121
%v = call <8 x half> @llvm.vp.fabs.v8f16(<8 x half> %va, <8 x i1> splat (i1 true), i32 %evl)
134122
ret <8 x half> %v
@@ -145,12 +133,10 @@ define <16 x half> @vfabs_vv_v16f16(<16 x half> %va, <16 x i1> %m, i32 zeroext %
145133
;
146134
; ZVFHMIN-LABEL: vfabs_vv_v16f16:
147135
; ZVFHMIN: # %bb.0:
148-
; ZVFHMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
149-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v12, v8
150-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, m4, ta, ma
151-
; ZVFHMIN-NEXT: vfabs.v v12, v12, v0.t
152-
; ZVFHMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
153-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v12
136+
; ZVFHMIN-NEXT: lui a1, 8
137+
; ZVFHMIN-NEXT: addi a1, a1, -1
138+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, m2, ta, ma
139+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1, v0.t
154140
; ZVFHMIN-NEXT: ret
155141
%v = call <16 x half> @llvm.vp.fabs.v16f16(<16 x half> %va, <16 x i1> %m, i32 %evl)
156142
ret <16 x half> %v
@@ -165,12 +151,10 @@ define <16 x half> @vfabs_vv_v16f16_unmasked(<16 x half> %va, i32 zeroext %evl)
165151
;
166152
; ZVFHMIN-LABEL: vfabs_vv_v16f16_unmasked:
167153
; ZVFHMIN: # %bb.0:
168-
; ZVFHMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
169-
; ZVFHMIN-NEXT: vfwcvt.f.f.v v12, v8
170-
; ZVFHMIN-NEXT: vsetvli zero, a0, e32, m4, ta, ma
171-
; ZVFHMIN-NEXT: vfabs.v v12, v12
172-
; ZVFHMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
173-
; ZVFHMIN-NEXT: vfncvt.f.f.w v8, v12
154+
; ZVFHMIN-NEXT: lui a1, 8
155+
; ZVFHMIN-NEXT: addi a1, a1, -1
156+
; ZVFHMIN-NEXT: vsetvli zero, a0, e16, m2, ta, ma
157+
; ZVFHMIN-NEXT: vand.vx v8, v8, a1
174158
; ZVFHMIN-NEXT: ret
175159
%v = call <16 x half> @llvm.vp.fabs.v16f16(<16 x half> %va, <16 x i1> splat (i1 true), i32 %evl)
176160
ret <16 x half> %v

0 commit comments

Comments
 (0)