Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "NVPTXISelDAGToDAG.h"
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
Expand Down Expand Up @@ -191,6 +192,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
}
break;
}
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
if (tryBF16ArithToFMA(N))
return;
break;
default:
break;
}
Expand Down Expand Up @@ -2450,6 +2457,62 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
return true;
}

// Select bf16/bf16v2 FADD, FSUB, FMUL as fma on targets with only fma
bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
EVT VT = SDValue(N, 0).getValueType();
if (VT.getScalarType() != MVT::bf16)
return false;

const NVPTXSubtarget *STI = TM.getSubtargetImpl();
if (STI->hasNativeBF16Support(N->getOpcode()))
return false;

const bool IsVec = VT.isVector();
assert(!IsVec || VT.getVectorNumElements() == 2);
SDLoc DL(N);
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SmallVector<SDValue, 3> Operands;
auto GetConstant = [&](float Value) -> SDValue {
// BF16 immediates must be legalized to integer register values
APFloat APF(Value);
bool LosesInfo;
APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
assert(!LosesInfo);
if (IsVec) {
auto API = APF.bitcastToAPInt();
API = API.concat(API);
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does PTX have a way to specify/use bf16 constant args? It would be nice to avoid passing the constant via a register.
I suspect there's no good way to use a constant for bf16x2, but I would assume that there should be a way to use FP constants for a scalar FMA.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No there doesn't seem to be any support for bf16 immediate values, ptxas complains

ptxas /tmp/tmplmhxk1av.ptx, line 79; error   : Arguments mismatch for instruction 'fma'
ptxas fatal   : Ptx assembly aborted due to errors

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Please add a comment about that. Otherwise these register moves look rather questionable.

}
auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0);
};

switch (N->getOpcode()) {
case ISD::FADD:
// add(a, b) -> fma(a, 1.0, b)
Operands = {N0, GetConstant(1.0), N1};
break;
case ISD::FSUB:
// sub(a, b) -> fma(b, -1.0, a)
Operands = {N1, GetConstant(-1.0), N0};
break;
case ISD::FMUL:
// mul(a, b) -> fma(a, b, -0.0)
// NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
Operands = {N0, N1, GetConstant(-0.0)};
break;
default:
llvm_unreachable("Unexpected opcode");
};

int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
ReplaceNode(N, FMA);
return true;
}

static inline bool isAddLike(const SDValue V) {
return V.getOpcode() == ISD::ADD ||
(V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint());
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryFence(SDNode *N);
void SelectAddrSpaceCast(SDNode *N);
bool tryBFE(SDNode *N);
bool tryBF16ArithToFMA(SDNode *N);
bool tryConstantFP(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
bool SelectSETP_BF16X2(SDNode *N);
Expand Down
65 changes: 37 additions & 28 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,34 +535,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,

auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
LegalizeAction NoBF16Action) {
bool IsOpSupported = STI.hasBF16Math();
switch (Op) {
// Several BF16 instructions are available on sm_90 only.
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
case ISD::SELECT:
case ISD::SELECT_CC:
case ISD::SETCC:
case ISD::FEXP2:
case ISD::FCEIL:
case ISD::FFLOOR:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::FROUNDEVEN:
case ISD::FTRUNC:
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
break;
// Several BF16 instructions are available on sm_80 only.
case ISD::FMINNUM:
case ISD::FMAXNUM:
case ISD::FMAXNUM_IEEE:
case ISD::FMINNUM_IEEE:
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
break;
}
bool IsOpSupported = STI.hasNativeBF16Support(Op);
setOperationAction(
Op, VT, IsOpSupported ? Action : NoBF16Action);
};
Expand Down Expand Up @@ -862,6 +835,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}

// On SM80, we select add/mul/sub as fma to avoid promotion to float
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
setOperationAction(Op, VT, Custom);
}
}
}

// f16/f16x2 neg was introduced in PTX 60, SM_53.
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
STI.getPTXVersion() >= 60 &&
Expand Down Expand Up @@ -2498,6 +2480,27 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}

static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
EVT NVT = MVT::f32;
if (VT.isVector()) {
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
}
SDLoc DL(N);
SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
return DAG.getFPExtendOrRound(Res, DL, VT);
}

SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
SelectionDAG &DAG) const {
if (useF32FTZ(DAG.getMachineFunction())) {
return PromoteBinOpToF32(Op.getNode(), DAG);
}
return Op;
}

SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
Expand Down Expand Up @@ -2689,6 +2692,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerSTACKSAVE(Op, DAG);
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
case ISD::FADD:
case ISD::FSUB:
case ISD::FMUL:
// Used only for bf16 on SM80, where we select fma for non-ftz operation
return PromoteBinOpIfF32FTZ(Op, DAG);

default:
llvm_unreachable("Custom lowering not defined for operation");
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;

SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;

Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,38 @@ bool NVPTXSubtarget::allowFP16Math() const {
return hasFP16Math() && NoF16Math == false;
}

bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
if (!hasBF16Math())
return false;

switch (Opcode) {
// Several BF16 instructions are available on sm_90 only.
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
case ISD::SELECT:
case ISD::SELECT_CC:
case ISD::SETCC:
case ISD::FEXP2:
case ISD::FCEIL:
case ISD::FFLOOR:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::FROUNDEVEN:
case ISD::FTRUNC:
return getSmVersion() >= 90 && getPTXVersion() >= 78;
// Several BF16 instructions are available on sm_80 only.
case ISD::FMINNUM:
case ISD::FMAXNUM:
case ISD::FMAXNUM_IEEE:
case ISD::FMINNUM_IEEE:
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
return getSmVersion() >= 80 && getPTXVersion() >= 70;
}
return true;
}

void NVPTXSubtarget::failIfClustersUnsupported(
std::string const &FailureMessage) const {
if (hasClusters())
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
}
bool hasTargetName() const { return !TargetName.empty(); }

bool hasNativeBF16Support(int Opcode) const;

// Get maximum value of required alignments among the supported data types.
// From the PTX ISA doc, section 8.2.3:
// The memory consistency model relates operations executed on memory
Expand Down
56 changes: 25 additions & 31 deletions llvm/test/CodeGen/NVPTX/atomics-sm90.ll
Original file line number Diff line number Diff line change
Expand Up @@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
; CHECKPTX71-LABEL: test(
; CHECKPTX71: {
; CHECKPTX71-NEXT: .reg .pred %p<5>;
; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
; CHECKPTX71-NEXT: .reg .f32 %f<12>;
; CHECKPTX71-EMPTY:
; CHECKPTX71-NEXT: // %bb.0:
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
; CHECKPTX71-NEXT: ret;
Expand Down
Loading
Loading