Skip to content

Commit 403aaee

Browse files
committed
Resolve NITs
1 parent 12b5115 commit 403aaee

File tree

4 files changed

+41
-36
lines changed

4 files changed

+41
-36
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,18 +2464,17 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
24642464
return false;
24652465

24662466
const NVPTXSubtarget *STI = TM.getSubtargetImpl();
2467-
const bool IsNativelySupported =
2468-
STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
2469-
if (IsNativelySupported)
2467+
if (STI->hasNativeBF16Support(N->getOpcode()))
24702468
return false;
24712469

2472-
assert(VT == MVT::bf16 || VT == MVT::v2bf16);
2473-
const bool IsVec = VT == MVT::v2bf16;
2470+
const bool IsVec = VT.isVector();
2471+
assert(!IsVec || VT.getVectorNumElements() == 2);
24742472
SDLoc DL(N);
24752473
SDValue N0 = N->getOperand(0);
24762474
SDValue N1 = N->getOperand(1);
24772475
SmallVector<SDValue, 3> Operands;
24782476
auto GetConstant = [&](float Value) -> SDValue {
2477+
// BF16 immediates must be legalized to integer register values
24792478
APFloat APF(Value);
24802479
bool LosesInfo;
24812480
APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -535,34 +535,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
535535

536536
auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
537537
LegalizeAction NoBF16Action) {
538-
bool IsOpSupported = STI.hasBF16Math();
539-
switch (Op) {
540-
// Several BF16 instructions are available on sm_90 only.
541-
case ISD::FADD:
542-
case ISD::FMUL:
543-
case ISD::FSUB:
544-
case ISD::SELECT:
545-
case ISD::SELECT_CC:
546-
case ISD::SETCC:
547-
case ISD::FEXP2:
548-
case ISD::FCEIL:
549-
case ISD::FFLOOR:
550-
case ISD::FNEARBYINT:
551-
case ISD::FRINT:
552-
case ISD::FROUNDEVEN:
553-
case ISD::FTRUNC:
554-
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
555-
break;
556-
// Several BF16 instructions are available on sm_80 only.
557-
case ISD::FMINNUM:
558-
case ISD::FMAXNUM:
559-
case ISD::FMAXNUM_IEEE:
560-
case ISD::FMINNUM_IEEE:
561-
case ISD::FMAXIMUM:
562-
case ISD::FMINIMUM:
563-
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
564-
break;
565-
}
538+
bool IsOpSupported = STI.hasNativeBF16Support(Op);
566539
setOperationAction(
567540
Op, VT, IsOpSupported ? Action : NoBF16Action);
568541
};
@@ -862,11 +835,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
862835
AddPromotedToType(Op, MVT::bf16, MVT::f32);
863836
}
864837

865-
// Lower bf16 add/mul/sub as fma when it avoids promotion
838+
// When bf16 add/mul/sub are not natively supported but bf16 fma is (e.g. on
// SM80), select them as fma to avoid promotion to float
866839
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
867840
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
868-
if (getOperationAction(Op, VT) != Legal &&
869-
getOperationAction(ISD::FMA, VT) == Legal) {
841+
if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
870842
setOperationAction(Op, VT, Custom);
871843
}
872844
}

llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,38 @@ bool NVPTXSubtarget::allowFP16Math() const {
7070
return hasFP16Math() && NoF16Math == false;
7171
}
7272

73+
bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
74+
if (!hasBF16Math())
75+
return false;
76+
77+
switch (Opcode) {
78+
// Several BF16 instructions are available on sm_90 only.
79+
case ISD::FADD:
80+
case ISD::FMUL:
81+
case ISD::FSUB:
82+
case ISD::SELECT:
83+
case ISD::SELECT_CC:
84+
case ISD::SETCC:
85+
case ISD::FEXP2:
86+
case ISD::FCEIL:
87+
case ISD::FFLOOR:
88+
case ISD::FNEARBYINT:
89+
case ISD::FRINT:
90+
case ISD::FROUNDEVEN:
91+
case ISD::FTRUNC:
92+
return getSmVersion() >= 90 && getPTXVersion() >= 78;
93+
// Several BF16 instructions are available on sm_80 only.
94+
case ISD::FMINNUM:
95+
case ISD::FMAXNUM:
96+
case ISD::FMAXNUM_IEEE:
97+
case ISD::FMINNUM_IEEE:
98+
case ISD::FMAXIMUM:
99+
case ISD::FMINIMUM:
100+
return getSmVersion() >= 80 && getPTXVersion() >= 70;
101+
}
102+
return true;
103+
}
104+
73105
void NVPTXSubtarget::failIfClustersUnsupported(
74106
std::string const &FailureMessage) const {
75107
if (hasClusters())

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
118118
}
119119
bool hasTargetName() const { return !TargetName.empty(); }
120120

121+
bool hasNativeBF16Support(int Opcode) const;
122+
121123
// Get maximum value of required alignments among the supported data types.
122124
// From the PTX ISA doc, section 8.2.3:
123125
// The memory consistency model relates operations executed on memory

0 commit comments

Comments
 (0)