Skip to content

Commit 5e5fd0e

Browse files
authored
[NVPTX] Select bfloat16 add/mul/sub as fma on SM80 (#121065)
SM80 has fma for bfloat16 but not add/mul/sub. Currently these ops incur a promotion to f32, but we can avoid this by writing them in terms of the fma:
```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, -0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```
Unfortunately there is no `fma.ftz`, so when ftz is enabled we still fall back to promotion.
1 parent 9033e0c commit 5e5fd0e

12 files changed

+276
-398
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "NVPTXISelDAGToDAG.h"
14+
#include "NVPTX.h"
1415
#include "NVPTXUtilities.h"
1516
#include "llvm/Analysis/ValueTracking.h"
1617
#include "llvm/CodeGen/ISDOpcodes.h"
@@ -191,6 +192,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
191192
}
192193
break;
193194
}
195+
case ISD::FADD:
196+
case ISD::FMUL:
197+
case ISD::FSUB:
198+
if (tryBF16ArithToFMA(N))
199+
return;
200+
break;
194201
default:
195202
break;
196203
}
@@ -2450,6 +2457,62 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
24502457
return true;
24512458
}
24522459

2460+
// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
//
// Returns true if N was replaced by a BFMA16rrr / BFMA16x2rrr machine node.
// Returns false, leaving N untouched, when the result's scalar type is not
// bf16 or when the subtarget reports native bf16 support for N's opcode.
2461+
bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
2462+
EVT VT = SDValue(N, 0).getValueType();
2463+
if (VT.getScalarType() != MVT::bf16)
2464+
return false;
2465+
2466+
const NVPTXSubtarget *STI = TM.getSubtargetImpl();
2467+
// If the op itself is natively supported there is nothing to rewrite.
if (STI->hasNativeBF16Support(N->getOpcode()))
2468+
return false;
2469+
2470+
const bool IsVec = VT.isVector();
2471+
// The only vector form handled here is the two-element (x2) one.
assert(!IsVec || VT.getVectorNumElements() == 2);
2472+
SDLoc DL(N);
2473+
SDValue N0 = N->getOperand(0);
2474+
SDValue N1 = N->getOperand(1);
2475+
SmallVector<SDValue, 3> Operands;
2476+
// Materializes Value as a bf16 (or splatted bf16x2) constant held in a
// register via a move-immediate machine node.
auto GetConstant = [&](float Value) -> SDValue {
2477+
// BF16 immediates must be legalized to integer register values
2478+
APFloat APF(Value);
2479+
bool LosesInfo;
2480+
APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
2481+
// Callers only pass 1.0, -1.0 and -0.0, which must convert to bf16 exactly.
assert(!LosesInfo);
2482+
if (IsVec) {
2483+
// Splat: concatenate the bf16 bit pattern with itself and move the
// result as a single i32 immediate.
auto API = APF.bitcastToAPInt();
2484+
API = API.concat(API);
2485+
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
2486+
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
2487+
}
2488+
auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
2489+
return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0);
2490+
};
2491+
2492+
// Build the three fma operands (multiplicand, multiplier, addend).
switch (N->getOpcode()) {
2493+
case ISD::FADD:
2494+
// add(a, b) -> fma(a, 1.0, b)
2495+
Operands = {N0, GetConstant(1.0), N1};
2496+
break;
2497+
case ISD::FSUB:
2498+
// sub(a, b) -> fma(b, -1.0, a)
// i.e. a - b is computed as (-b) + a, so the operands are swapped.
2499+
Operands = {N1, GetConstant(-1.0), N0};
2500+
break;
2501+
case ISD::FMUL:
2502+
// mul(a, b) -> fma(a, b, -0.0)
2503+
// NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
// (adding +0 would turn a -0 product into +0; adding -0 preserves it).
2504+
Operands = {N0, N1, GetConstant(-0.0)};
2505+
break;
2506+
default:
2507+
llvm_unreachable("Unexpected opcode");
2508+
};
2509+
2510+
// Pick the packed or scalar fma opcode to match the result type.
int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
2511+
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
2512+
ReplaceNode(N, FMA);
2513+
return true;
2514+
}
2515+
24532516
static inline bool isAddLike(const SDValue V) {
24542517
return V.getOpcode() == ISD::ADD ||
24552518
(V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint());

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
8484
bool tryFence(SDNode *N);
8585
void SelectAddrSpaceCast(SDNode *N);
8686
bool tryBFE(SDNode *N);
87+
bool tryBF16ArithToFMA(SDNode *N);
8788
bool tryConstantFP(SDNode *N);
8889
bool SelectSETP_F16X2(SDNode *N);
8990
bool SelectSETP_BF16X2(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -535,34 +535,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
535535

536536
auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
537537
LegalizeAction NoBF16Action) {
538-
bool IsOpSupported = STI.hasBF16Math();
539-
switch (Op) {
540-
// Several BF16 instructions are available on sm_90 only.
541-
case ISD::FADD:
542-
case ISD::FMUL:
543-
case ISD::FSUB:
544-
case ISD::SELECT:
545-
case ISD::SELECT_CC:
546-
case ISD::SETCC:
547-
case ISD::FEXP2:
548-
case ISD::FCEIL:
549-
case ISD::FFLOOR:
550-
case ISD::FNEARBYINT:
551-
case ISD::FRINT:
552-
case ISD::FROUNDEVEN:
553-
case ISD::FTRUNC:
554-
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
555-
break;
556-
// Several BF16 instructions are available on sm_80 only.
557-
case ISD::FMINNUM:
558-
case ISD::FMAXNUM:
559-
case ISD::FMAXNUM_IEEE:
560-
case ISD::FMINNUM_IEEE:
561-
case ISD::FMAXIMUM:
562-
case ISD::FMINIMUM:
563-
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
564-
break;
565-
}
538+
bool IsOpSupported = STI.hasNativeBF16Support(Op);
566539
setOperationAction(
567540
Op, VT, IsOpSupported ? Action : NoBF16Action);
568541
};
@@ -862,6 +835,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
862835
AddPromotedToType(Op, MVT::bf16, MVT::f32);
863836
}
864837

838+
// On SM80, we select add/mul/sub as fma to avoid promotion to float
839+
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
840+
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
841+
if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
842+
setOperationAction(Op, VT, Custom);
843+
}
844+
}
845+
}
846+
865847
// f16/f16x2 neg was introduced in PTX 60, SM_53.
866848
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
867849
STI.getPTXVersion() >= 60 &&
@@ -2498,6 +2480,27 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
24982480
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
24992481
}
25002482

2483+
// Rebuilds the two-operand FP node N so that the operation is carried out in
// f32 (or a vector of f32 with the same element count as N's result type):
// both operands are converted to the wider type, the same opcode is applied
// with N's flags, and the result is converted back to N's original type.
static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2484+
EVT VT = N->getValueType(0);
2485+
EVT NVT = MVT::f32;
2486+
// For vector results, promote element-wise: keep the element count and
// switch the element type to f32.
if (VT.isVector()) {
2487+
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2488+
}
2489+
SDLoc DL(N);
2490+
SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
2491+
SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
2492+
// Same opcode and flags as the original node, but in the promoted type.
SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
2493+
// Convert the promoted result back to the original value type.
return DAG.getFPExtendOrRound(Res, DL, VT);
2494+
}
2495+
2496+
// Custom-lowering hook for the bf16 FADD/FSUB/FMUL ops that are otherwise
// selected as fma. When the function uses f32 flush-to-zero, the op is
// promoted to f32 instead, because there is no `fma.ftz` to select for bf16.
// Otherwise Op is returned unchanged.
// NOTE(review): returning Op unchanged presumably tells the legalizer to keep
// the node as-is so instruction selection can turn it into an fma - confirm.
SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2497+
SelectionDAG &DAG) const {
2498+
if (useF32FTZ(DAG.getMachineFunction())) {
2499+
return PromoteBinOpToF32(Op.getNode(), DAG);
2500+
}
2501+
return Op;
2502+
}
2503+
25012504
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
25022505
SelectionDAG &DAG) const {
25032506
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2689,6 +2692,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
26892692
return LowerSTACKSAVE(Op, DAG);
26902693
case ISD::CopyToReg:
26912694
return LowerCopyToReg_128(Op, DAG);
2695+
case ISD::FADD:
2696+
case ISD::FSUB:
2697+
case ISD::FMUL:
2698+
// Used only for bf16 on SM80, where we select fma for non-ftz operation
2699+
return PromoteBinOpIfF32FTZ(Op, DAG);
2700+
26922701
default:
26932702
llvm_unreachable("Custom lowering not defined for operation");
26942703
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ class NVPTXTargetLowering : public TargetLowering {
278278
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
279279
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
280280

281+
SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;
282+
281283
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
282284
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
283285

llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,38 @@ bool NVPTXSubtarget::allowFP16Math() const {
7070
return hasFP16Math() && NoF16Math == false;
7171
}
7272

73+
// Reports whether this subtarget natively supports the given ISD opcode on
// bf16. Always false when hasBF16Math() is false. Otherwise the opcodes
// listed below are gated on minimum SM and PTX versions, and every opcode
// not listed is reported as supported.
bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
74+
if (!hasBF16Math())
75+
return false;
76+
77+
switch (Opcode) {
78+
// These BF16 instructions require sm_90 or newer together with PTX 7.8 or
// newer.
79+
case ISD::FADD:
80+
case ISD::FMUL:
81+
case ISD::FSUB:
82+
case ISD::SELECT:
83+
case ISD::SELECT_CC:
84+
case ISD::SETCC:
85+
case ISD::FEXP2:
86+
case ISD::FCEIL:
87+
case ISD::FFLOOR:
88+
case ISD::FNEARBYINT:
89+
case ISD::FRINT:
90+
case ISD::FROUNDEVEN:
91+
case ISD::FTRUNC:
92+
return getSmVersion() >= 90 && getPTXVersion() >= 78;
93+
// These BF16 instructions require sm_80 or newer together with PTX 7.0 or
// newer.
94+
case ISD::FMINNUM:
95+
case ISD::FMAXNUM:
96+
case ISD::FMAXNUM_IEEE:
97+
case ISD::FMINNUM_IEEE:
98+
case ISD::FMAXIMUM:
99+
case ISD::FMINIMUM:
100+
return getSmVersion() >= 80 && getPTXVersion() >= 70;
101+
}
102+
// Any other opcode is considered supported once hasBF16Math() holds.
return true;
103+
}
104+
73105
void NVPTXSubtarget::failIfClustersUnsupported(
74106
std::string const &FailureMessage) const {
75107
if (hasClusters())

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
118118
}
119119
bool hasTargetName() const { return !TargetName.empty(); }
120120

121+
bool hasNativeBF16Support(int Opcode) const;
122+
121123
// Get maximum value of required alignments among the supported data types.
122124
// From the PTX ISA doc, section 8.2.3:
123125
// The memory consistency model relates operations executed on memory

llvm/test/CodeGen/NVPTX/atomics-sm90.ll

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
4646
; CHECKPTX71-LABEL: test(
4747
; CHECKPTX71: {
4848
; CHECKPTX71-NEXT: .reg .pred %p<5>;
49-
; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
49+
; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
5050
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
51-
; CHECKPTX71-NEXT: .reg .f32 %f<12>;
5251
; CHECKPTX71-EMPTY:
5352
; CHECKPTX71-NEXT: // %bb.0:
5453
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
5554
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
5655
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
5756
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
58-
; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
59-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
57+
; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
6058
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
6159
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
62-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
63-
; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
64-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
65-
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
66-
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
67-
; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
60+
; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
61+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
62+
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
63+
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
64+
; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
6865
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
6966
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
70-
; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
67+
; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
7168
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
7269
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
73-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
74-
; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
75-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
76-
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
77-
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
78-
; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
70+
; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
71+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
72+
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
73+
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
74+
; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
7975
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
8076
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
81-
; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
77+
; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
8278
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
8379
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
84-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
85-
; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
86-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
87-
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
88-
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
89-
; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
80+
; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
81+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
82+
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
83+
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
84+
; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
9085
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
9186
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
92-
; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
87+
; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
9388
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
9489
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
95-
; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
96-
; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
97-
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
98-
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
99-
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
100-
; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
90+
; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
91+
; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
92+
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
93+
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
94+
; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
10195
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
10296
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
10397
; CHECKPTX71-NEXT: ret;

0 commit comments

Comments
 (0)