Skip to content

Commit f145f25

Browse files
committed
Move fma expansions to default expand rule
1 parent 59b06e7 commit f145f25

File tree

4 files changed

+88
-25
lines changed

4 files changed

+88
-25
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5325,6 +5325,23 @@ class TargetLowering : public TargetLoweringBase {
53255325
SDNodeFlags Flags, const SDLoc &DL,
53265326
SelectionDAG &DAG) const;
53275327

5328+
/// Expand floating point add
5329+
/// \param N Node to expand
5330+
/// \returns The expansion result or SDValue() if it fails.
5331+
SDValue expandFADD(SDNode *N, SelectionDAG &DAG) const;
5332+
5333+
/// Expand floating point multiply
5334+
/// \param N Node to expand
5335+
/// \param DAG SelectionDAG in which the replacement nodes are created
5336+
/// \returns The expansion result or SDValue() if it fails.
5337+
SDValue expandFMUL(SDNode *N, SelectionDAG &DAG) const;
5338+
5339+
/// Expand floating point subtract
5340+
/// \param N Node to expand
5341+
/// \param DAG SelectionDAG in which the replacement nodes are created
5342+
/// \returns The expansion result or SDValue() if it fails.
5343+
SDValue expandFSUB(SDNode *N, SelectionDAG &DAG) const;
5344+
53285345
/// Expand CTPOP nodes. Expands vector/scalar CTPOP nodes,
53295346
/// vector nodes can only succeed if all operations are legal/custom.
53305347
/// \param N Node to expand

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3671,14 +3671,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
36713671
Results.push_back(ExpandConstant(CP));
36723672
break;
36733673
}
3674+
case ISD::FADD: {
3675+
if (SDValue Expand = TLI.expandFADD(Node, DAG)) {
3676+
Results.push_back(Expand);
3677+
}
3678+
break;
3679+
}
3680+
case ISD::FMUL: {
3681+
if (SDValue Expand = TLI.expandFMUL(Node, DAG)) {
3682+
Results.push_back(Expand);
3683+
}
3684+
break;
3685+
}
36743686
case ISD::FSUB: {
3675-
EVT VT = Node->getValueType(0);
3676-
if (TLI.isOperationLegalOrCustom(ISD::FADD, VT) &&
3677-
TLI.isOperationLegalOrCustom(ISD::FNEG, VT)) {
3678-
const SDNodeFlags Flags = Node->getFlags();
3679-
Tmp1 = DAG.getNode(ISD::FNEG, dl, VT, Node->getOperand(1));
3680-
Tmp1 = DAG.getNode(ISD::FADD, dl, VT, Node->getOperand(0), Tmp1, Flags);
3681-
Results.push_back(Tmp1);
3687+
if (SDValue Expand = TLI.expandFSUB(Node, DAG)) {
3688+
Results.push_back(Expand);
36823689
}
36833690
break;
36843691
}

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9070,6 +9070,60 @@ SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
90709070
return Res;
90719071
}
90729072

9073+
SDValue TargetLowering::expandFADD(SDNode *Node, SelectionDAG &DAG) const {
9074+
auto VT = Node->getValueType(0);
9075+
if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
9076+
return {};
9077+
}
9078+
9079+
// FADD(a, b) -> FMA(a, 1.0, b)
9080+
SDLoc DL(Node);
9081+
auto One = DAG.getConstantFP(1.0, DL, VT);
9082+
SmallVector<SDValue, 3> Operands{Node->getOperand(0), One,
9083+
Node->getOperand(1)};
9084+
return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
9085+
}
9086+
9087+
SDValue TargetLowering::expandFMUL(SDNode *Node, SelectionDAG &DAG) const {
9088+
auto VT = Node->getValueType(0);
9089+
if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
9090+
return {};
9091+
}
9092+
9093+
// FMUL(a, b) -> FMA(a, b, -0.0)
9094+
// NOTE: The additive identity is -0.0, not +0.0: x + (-0.0) == x for every x, whereas (-0.0) + (+0.0) == +0.0 would lose the sign of a -0.0 product
9095+
SDLoc DL(Node);
9096+
auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
9097+
SmallVector<SDValue, 3> Operands{Node->getOperand(0), Node->getOperand(1),
9098+
NegZero};
9099+
return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
9100+
}
9101+
9102+
SDValue TargetLowering::expandFSUB(SDNode *Node, SelectionDAG &DAG) const {
9103+
SDLoc DL(Node);
9104+
SDNodeFlags SDFlags = Node->getFlags();
9105+
auto VT = Node->getValueType(0);
9106+
9107+
bool CanUseFMA = isOperationLegalOrCustom(ISD::FMA, VT);
9108+
bool CanUseAddSub = (isOperationLegalOrCustom(ISD::FADD, VT) &&
9109+
isOperationLegalOrCustom(ISD::FNEG, VT));
9110+
bool PreferAddSub = CanUseAddSub && isFNegFree(VT);
9111+
9112+
// FSUB(a, b) -> FMA(b, -1.0, a)
9113+
if (CanUseFMA && !PreferAddSub) {
9114+
auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
9115+
SmallVector<SDValue, 3> Operands{Node->getOperand(1), NegOne,
9116+
Node->getOperand(0)};
9117+
return DAG.getNode(ISD::FMA, DL, VT, Operands, SDFlags);
9118+
}
9119+
// FSUB(a, b) -> FADD(a, FNEG(b))
9120+
if (CanUseAddSub) {
9121+
auto Neg = DAG.getNode(ISD::FNEG, DL, VT, Node->getOperand(1));
9122+
return DAG.getNode(ISD::FADD, DL, VT, Node->getOperand(0), Neg, SDFlags);
9123+
}
9124+
return {};
9125+
}
9126+
90739127
// Only expand vector types if we have the appropriate vector bit operations.
90749128
static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
90759129
assert(VT.isVector() && "Expected vector type");

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2520,11 +2520,7 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
25202520
}
25212521

25222522
// FADD(a, b) -> FMA(a, 1.0, b)
2523-
SDLoc DL(Op);
2524-
auto VT = Op.getValueType();
2525-
auto One = DAG.getConstantFP(1.0, DL, VT);
2526-
SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
2527-
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2523+
return expandFADD(Op.getNode(), DAG);
25282524
}
25292525

25302526
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
@@ -2534,12 +2530,7 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
25342530
}
25352531

25362532
// FSUB(a, b) -> FMA(b, -1.0, a)
2537-
SDLoc DL(Op);
2538-
auto VT = Op.getValueType();
2539-
auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
2540-
SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
2541-
Op->getOperand(0)};
2542-
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2533+
return expandFSUB(Op.getNode(), DAG);
25432534
}
25442535

25452536
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
@@ -2549,13 +2540,7 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
25492540
}
25502541

25512542
// FMUL(a, b) -> FMA(a, b, -0.0)
2552-
// NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
2553-
SDLoc DL(Op);
2554-
auto VT = Op.getValueType();
2555-
auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
2556-
SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1),
2557-
NegZero};
2558-
return DAG.getNode(ISD::FMA, DL, VT, Operands);
2543+
return expandFMUL(Op.getNode(), DAG);
25592544
}
25602545

25612546
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,

0 commit comments

Comments
 (0)