Skip to content

Commit f634361

Browse files
committed
Move fma to selection stage
1 parent f8eb2e0 commit f634361

File tree

12 files changed

+131
-143
lines changed

12 files changed

+131
-143
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5325,23 +5325,6 @@ class TargetLowering : public TargetLoweringBase {
53255325
SDNodeFlags Flags, const SDLoc &DL,
53265326
SelectionDAG &DAG) const;
53275327

5328-
/// Expand floating point add
5329-
/// \param N Node to expand
5330-
/// \returns The expansion result or SDValue() if it fails.
5331-
SDValue expandFADD(SDNode *N, SelectionDAG &DAG) const;
5332-
5333-
/// Expand floating point multiply
5334-
/// \param N Node to expand
5335-
/// \param Result output after conversion
5336-
/// \returns The expansion result or SDValue() if it fails.
5337-
SDValue expandFMUL(SDNode *N, SelectionDAG &DAG) const;
5338-
5339-
/// Expand floating point subtract
5340-
/// \param N Node to expand
5341-
/// \param Result output after conversion
5342-
/// \returns The expansion result or SDValue() if it fails.
5343-
SDValue expandFSUB(SDNode *N, SelectionDAG &DAG) const;
5344-
53455328
/// Expand CTPOP nodes. Expands vector/scalar CTPOP nodes,
53465329
/// vector nodes can only succeed if all operations are legal/custom.
53475330
/// \param N Node to expand

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17551,13 +17551,10 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1755117551
return N2;
1755217552
}
1755317553

17554-
const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
17555-
!TLI.isOperationLegal(ISD::FADD, VT));
17556-
1755717554
// FIXME: Support splat of constant.
17558-
if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
17555+
if (N0CFP && N0CFP->isExactlyValue(1.0))
1755917556
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
17560-
if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
17557+
if (N1CFP && N1CFP->isExactlyValue(1.0))
1756117558
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1756217559

1756317560
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17589,7 +17586,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1758917586

1759017587
// (fma x, -1, y) -> (fadd (fneg x), y)
1759117588
// FIXME: Support splat of constant.
17592-
if (N1CFP && !PreferFMAAdd) {
17589+
if (N1CFP) {
1759317590
if (N1CFP->isExactlyValue(1.0))
1759417591
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
1759517592

@@ -17599,14 +17596,15 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
1759917596
AddToWorklist(RHSNeg.getNode());
1760017597
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
1760117598
}
17602-
}
17603-
// fma (fneg x), K, y -> fma x -K, y
17604-
if (N1CFP && matcher.match(N0, ISD::FNEG) &&
17605-
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17606-
(N1.hasOneUse() &&
17607-
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17608-
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17609-
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17599+
17600+
// fma (fneg x), K, y -> fma x -K, y
17601+
if (matcher.match(N0, ISD::FNEG) &&
17602+
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17603+
(N1.hasOneUse() &&
17604+
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17605+
return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17606+
matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17607+
}
1761017608
}
1761117609

1761217610
// FIXME: Support splat of constant.

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3671,21 +3671,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
36713671
Results.push_back(ExpandConstant(CP));
36723672
break;
36733673
}
3674-
case ISD::FADD: {
3675-
if (SDValue Expand = TLI.expandFADD(Node, DAG)) {
3676-
Results.push_back(Expand);
3677-
}
3678-
break;
3679-
}
3680-
case ISD::FMUL: {
3681-
if (SDValue Expand = TLI.expandFMUL(Node, DAG)) {
3682-
Results.push_back(Expand);
3683-
}
3684-
break;
3685-
}
36863674
case ISD::FSUB: {
3687-
if (SDValue Expand = TLI.expandFSUB(Node, DAG)) {
3688-
Results.push_back(Expand);
3675+
EVT VT = Node->getValueType(0);
3676+
if (TLI.isOperationLegalOrCustom(ISD::FADD, VT) &&
3677+
TLI.isOperationLegalOrCustom(ISD::FNEG, VT)) {
3678+
const SDNodeFlags Flags = Node->getFlags();
3679+
Tmp1 = DAG.getNode(ISD::FNEG, dl, VT, Node->getOperand(1));
3680+
Tmp1 = DAG.getNode(ISD::FADD, dl, VT, Node->getOperand(0), Tmp1, Flags);
3681+
Results.push_back(Tmp1);
36893682
}
36903683
break;
36913684
}

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9068,60 +9068,6 @@ SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
90689068
return Res;
90699069
}
90709070

9071-
SDValue TargetLowering::expandFADD(SDNode *Node, SelectionDAG &DAG) const {
9072-
auto VT = Node->getValueType(0);
9073-
if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
9074-
return {};
9075-
}
9076-
9077-
// FADD(a, b) -> FMA(a, 1.0, b)
9078-
SDLoc DL(Node);
9079-
auto One = DAG.getConstantFP(1.0, DL, VT);
9080-
SmallVector<SDValue, 3> Operands{Node->getOperand(0), One,
9081-
Node->getOperand(1)};
9082-
return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
9083-
}
9084-
9085-
SDValue TargetLowering::expandFMUL(SDNode *Node, SelectionDAG &DAG) const {
9086-
auto VT = Node->getValueType(0);
9087-
if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
9088-
return {};
9089-
}
9090-
9091-
// FMUL(a, b) -> FMA(a, b, -0.0)
9092-
// NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
9093-
SDLoc DL(Node);
9094-
auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
9095-
SmallVector<SDValue, 3> Operands{Node->getOperand(0), Node->getOperand(1),
9096-
NegZero};
9097-
return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
9098-
}
9099-
9100-
SDValue TargetLowering::expandFSUB(SDNode *Node, SelectionDAG &DAG) const {
9101-
SDLoc DL(Node);
9102-
SDNodeFlags SDFlags = Node->getFlags();
9103-
auto VT = Node->getValueType(0);
9104-
9105-
bool CanUseFMA = isOperationLegalOrCustom(ISD::FMA, VT);
9106-
bool CanUseAddSub = (isOperationLegalOrCustom(ISD::FADD, VT) &&
9107-
isOperationLegalOrCustom(ISD::FNEG, VT));
9108-
bool PreferAddSub = CanUseAddSub && isFNegFree(VT);
9109-
9110-
// FSUB(a, b) -> FMA(b, -1.0, a)
9111-
if (CanUseFMA && !PreferAddSub) {
9112-
auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
9113-
SmallVector<SDValue, 3> Operands{Node->getOperand(1), NegOne,
9114-
Node->getOperand(0)};
9115-
return DAG.getNode(ISD::FMA, DL, VT, Operands, SDFlags);
9116-
}
9117-
// FSUB(a, b) -> FADD(a, FNEG(b))
9118-
if (CanUseAddSub) {
9119-
auto Neg = DAG.getNode(ISD::FNEG, DL, VT, Node->getOperand(1));
9120-
return DAG.getNode(ISD::FADD, DL, VT, Node->getOperand(0), Neg, SDFlags);
9121-
}
9122-
return {};
9123-
}
9124-
91259071
// Only expand vector types if we have the appropriate vector bit operations.
91269072
static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
91279073
assert(VT.isVector() && "Expected vector type");

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "NVPTXISelDAGToDAG.h"
14+
#include "NVPTX.h"
1415
#include "NVPTXUtilities.h"
1516
#include "llvm/Analysis/ValueTracking.h"
1617
#include "llvm/CodeGen/ISDOpcodes.h"
@@ -190,6 +191,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
190191
return;
191192
}
192193
break;
194+
case ISD::FADD:
195+
case ISD::FMUL:
196+
case ISD::FSUB:
197+
if (tryBF16ArithToFMA(N))
198+
return;
199+
break;
193200
}
194201
default:
195202
break;
@@ -2450,6 +2457,66 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
24502457
return true;
24512458
}
24522459

2460+
// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
2461+
bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
2462+
EVT VT = SDValue(N, 0).getValueType();
2463+
if (VT.getScalarType() != MVT::bf16)
2464+
return false;
2465+
2466+
const NVPTXSubtarget *STI = TM.getSubtargetImpl();
2467+
const bool IsNativelySupported =
2468+
STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
2469+
if (IsNativelySupported)
2470+
return false;
2471+
2472+
assert(VT == MVT::bf16 || VT == MVT::v2bf16);
2473+
const bool IsVec = VT == MVT::v2bf16;
2474+
SDLoc DL(N);
2475+
SDValue N0 = N->getOperand(0);
2476+
SDValue N1 = N->getOperand(1);
2477+
SmallVector<SDValue, 3> Operands;
2478+
auto GetConstant = [&](float Value) -> SDValue {
2479+
APFloat APF(Value);
2480+
bool LosesInfo;
2481+
APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
2482+
assert(!LosesInfo);
2483+
if (IsVec) {
2484+
auto API = APF.bitcastToAPInt();
2485+
API = API.concat(API);
2486+
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
2487+
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
2488+
}
2489+
auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
2490+
return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0);
2491+
};
2492+
2493+
switch (N->getOpcode()) {
2494+
case ISD::FADD: {
2495+
// add(a, b) -> fma(a, 1.0, b)
2496+
Operands = {N0, GetConstant(1.0), N1};
2497+
break;
2498+
}
2499+
case ISD::FSUB: {
2500+
// sub(a, b) -> fma(b, -1.0, a)
2501+
Operands = {N1, GetConstant(-1.0), N0};
2502+
break;
2503+
}
2504+
case ISD::FMUL: {
2505+
// mul(a, b) -> fma(a, b, -0.0)
2506+
// NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
2507+
Operands = {N0, N1, GetConstant(-0.0)};
2508+
break;
2509+
}
2510+
default:
2511+
llvm_unreachable("Unexpected opcode");
2512+
};
2513+
2514+
int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
2515+
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
2516+
ReplaceNode(N, FMA);
2517+
return true;
2518+
}
2519+
24532520
static inline bool isAddLike(const SDValue V) {
24542521
return V.getOpcode() == ISD::ADD ||
24552522
(V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint());

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
8484
bool tryFence(SDNode *N);
8585
void SelectAddrSpaceCast(SDNode *N);
8686
bool tryBFE(SDNode *N);
87+
bool tryBF16ArithToFMA(SDNode *N);
8788
bool tryConstantFP(SDNode *N);
8889
bool SelectSETP_F16X2(SDNode *N);
8990
bool SelectSETP_BF16X2(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,8 +2519,8 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
25192519
return PromoteBinOpToF32(Op.getNode(), DAG);
25202520
}
25212521

2522-
// FADD(a, b) -> FMA(a, 1.0, b)
2523-
return expandFADD(Op.getNode(), DAG);
2522+
// Legal
2523+
return Op;
25242524
}
25252525

25262526
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
@@ -2529,8 +2529,8 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
25292529
return PromoteBinOpToF32(Op.getNode(), DAG);
25302530
}
25312531

2532-
// FSUB(a, b) -> FMA(b, -1.0, a)
2533-
return expandFSUB(Op.getNode(), DAG);
2532+
// Legal
2533+
return Op;
25342534
}
25352535

25362536
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
@@ -2539,8 +2539,8 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
25392539
return PromoteBinOpToF32(Op.getNode(), DAG);
25402540
}
25412541

2542-
// FMUL(a, b) -> FMA(a, b, -0.0)
2543-
return expandFMUL(Op.getNode(), DAG);
2542+
// Legal
2543+
return Op;
25442544
}
25452545

25462546
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
114114
; SM80-EMPTY:
115115
; SM80-NEXT: // %bb.0:
116116
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
117-
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
118-
; SM80-NEXT: mov.b16 %rs3, 0xBF80;
119-
; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
117+
; SM80-NEXT: mov.b16 %rs2, 0xBF80;
118+
; SM80-NEXT: ld.param.b16 %rs3, [test_fsub_param_1];
119+
; SM80-NEXT: fma.rn.bf16 %rs4, %rs3, %rs2, %rs1;
120120
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
121121
; SM80-NEXT: ret;
122122
;

llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
2626
; SM80-EMPTY:
2727
; SM80-NEXT: // %bb.0:
2828
; SM80-NEXT: ld.param.b32 %r1, [test_fadd_imm_0_param_0];
29-
; SM80-NEXT: mov.b32 %r2, 1073758080;
30-
; SM80-NEXT: mov.b32 %r3, 1065369472;
31-
; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
29+
; SM80-NEXT: mov.b32 %r2, 1065369472;
30+
; SM80-NEXT: mov.b32 %r3, 1073758080;
31+
; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r2, %r3;
3232
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
3333
; SM80-NEXT: ret;
3434
;

llvm/test/CodeGen/NVPTX/fma-relu-contract.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,10 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
361361
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
362362
; CHECK-NEXT: mov.b16 %rs5, 0x0000;
363363
; CHECK-NEXT: max.bf16 %rs6, %rs4, %rs5;
364-
; CHECK-NEXT: mov.b16 %rs7, 0x40E0;
365-
; CHECK-NEXT: mov.b16 %rs8, 0x3F80;
366-
; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
367-
; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
364+
; CHECK-NEXT: mov.b16 %rs7, 0x3F80;
365+
; CHECK-NEXT: mov.b16 %rs8, 0x40E0;
366+
; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs7, %rs8;
367+
; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs7, %rs9;
368368
; CHECK-NEXT: st.param.b16 [func_retval0], %rs10;
369369
; CHECK-NEXT: ret;
370370
;
@@ -957,10 +957,10 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
957957
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
958958
; CHECK-NEXT: mov.b32 %r5, 0;
959959
; CHECK-NEXT: max.bf16x2 %r6, %r4, %r5;
960-
; CHECK-NEXT: mov.b32 %r7, 1088438496;
961-
; CHECK-NEXT: mov.b32 %r8, 1065369472;
962-
; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r8, %r7;
963-
; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r8, %r9;
960+
; CHECK-NEXT: mov.b32 %r7, 1065369472;
961+
; CHECK-NEXT: mov.b32 %r8, 1088438496;
962+
; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r7, %r8;
963+
; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r7, %r9;
964964
; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
965965
; CHECK-NEXT: ret;
966966
;

0 commit comments

Comments
 (0)