Commit 0e8a414
[CUDA, NVPTX] Added basic __bf16 support for NVPTX.
Recent Clang changes expose __bf16 types for SSE2-enabled host compilations, which makes those types visible during GPU-side compilation, where Sema currently rejects them with an error that __bf16 is not supported.

Considering that __bf16 is a storage-only type, enabling it for NVPTX whenever it is enabled on the host should pose no correctness issues.

Recent NVIDIA GPUs have introduced bf16 support, so we will likely grow better support for __bf16 on NVPTX going forward.

Differential Revision: https://reviews.llvm.org/D136311
1 parent fd5a2bf commit 0e8a414
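
In practice, this lets device-side code that merely moves __bf16 values around compile cleanly, while arithmetic on __bf16 remains rejected (see the SemaCUDA test below). A minimal sketch of the kind of kernel this enables; the function name and loop are illustrative and not part of the commit:

// Illustrative sketch only (not from this commit): __bf16 as a storage-only
// type on NVPTX -- values may be loaded, stored, and copied, but not used
// in arithmetic.
__device__ void copy_bf16(__bf16 *dst, const __bf16 *src, int n) {
  for (int i = 0; i < n; ++i)
    dst[i] = src[i];           // lowered to ld.b16 / st.b16 on NVPTX
  // dst[0] = src[0] + src[0]; // still a Sema error: __bf16 arithmetic
  //                           // is not enabled by this change
}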

11 files changed: +303, -131 lines

clang/lib/Basic/Targets/NVPTX.cpp

Lines changed: 3 additions & 0 deletions

@@ -52,6 +52,9 @@ NVPTXTargetInfo::NVPTXTargetInfo(const llvm::Triple &Triple,
   VLASupported = false;
   AddrSpaceMap = &NVPTXAddrSpaceMap;
   UseAddrSpaceMapMangling = true;
+  // __bf16 is always available as a load/store only type.
+  BFloat16Width = BFloat16Align = 16;
+  BFloat16Format = &llvm::APFloat::BFloat();
 
   // Define available target features
   // These must be defined in sorted order!

clang/lib/Basic/Targets/NVPTX.h

Lines changed: 2 additions & 0 deletions

@@ -177,6 +177,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
   }
 
   bool hasBitIntType() const override { return true; }
+  bool hasBFloat16Type() const override { return true; }
+  const char *getBFloat16Mangling() const override { return "u6__bf16"; };
 };
 } // namespace targets
 } // namespace clang
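
The getBFloat16Mangling() override above is what produces the u6__bf16 pieces visible in the CodeGen test's CHECK-LABEL lines. For example (signature taken from the test; the annotation is just a reading aid):

__device__ void test_arg(__bf16 *out, __bf16 in);
// Itanium-mangled as: _Z8test_argPu6__bf16u6__bf16
//   Pu6__bf16 = pointer to __bf16 (vendor-extended type mangling)
//   u6__bf16  = __bf16 passed by value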

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+// REQUIRES: nvptx-registered-target
+// REQUIRES: x86-registered-target
+
+// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \
+// RUN:    -fcuda-is-device "-aux-target-cpu" "x86-64" -S -o - %s | FileCheck %s
+
+#include "Inputs/cuda.h"
+
+// CHECK-LABEL: .visible .func _Z8test_argPu6__bf16u6__bf16(
+// CHECK:       .param .b64 _Z8test_argPu6__bf16u6__bf16_param_0,
+// CHECK:       .param .b16 _Z8test_argPu6__bf16u6__bf16_param_1
+//
+__device__ void test_arg(__bf16 *out, __bf16 in) {
+// CHECK:       ld.param.b16 %{{h.*}}, [_Z8test_argPu6__bf16u6__bf16_param_1];
+  __bf16 bf16 = in;
+  *out = bf16;
+// CHECK:       st.b16
+// CHECK:       ret;
+}
+
+
+// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retu6__bf16(
+// CHECK:       .param .b16 _Z8test_retu6__bf16_param_0
+__device__ __bf16 test_ret(__bf16 in) {
+// CHECK:       ld.param.b16 %h{{.*}}, [_Z8test_retu6__bf16_param_0];
+  return in;
+// CHECK:       st.param.b16 [func_retval0+0], %h
+// CHECK:       ret;
+}
+
+// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callu6__bf16(
+// CHECK:       .param .b16 _Z9test_callu6__bf16_param_0
+__device__ __bf16 test_call(__bf16 in) {
+// CHECK:       ld.param.b16 %h{{.*}}, [_Z9test_callu6__bf16_param_0];
+// CHECK:       st.param.b16 [param0+0], %h2;
+// CHECK:       .param .b32 retval0;
+// CHECK:       call.uni (retval0),
+// CHECK-NEXT:  _Z8test_retu6__bf16,
+// CHECK-NEXT:  (
+// CHECK-NEXT:  param0
+// CHECK-NEXT:  );
+// CHECK:       ld.param.b16 %h{{.*}}, [retval0+0];
+  return test_ret(in);
+// CHECK:       st.param.b16 [func_retval0+0], %h
+// CHECK:       ret;
+}

clang/test/SemaCUDA/bf16.cu

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+// REQUIRES: nvptx-registered-target
+// REQUIRES: x86-registered-target
+
+// RUN: %clang_cc1 "-triple" "x86_64-unknown-linux-gnu" "-aux-triple" "nvptx64-nvidia-cuda" \
+// RUN:    "-target-cpu" "x86-64" -fsyntax-only -verify=scalar %s
+// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \
+// RUN:    -fcuda-is-device "-aux-target-cpu" "x86-64" -fsyntax-only -verify=scalar %s
+
+#include "Inputs/cuda.h"
+
+__device__ void test(bool b, __bf16 *out, __bf16 in) {
+  __bf16 bf16 = in; // No error on using the type itself.
+
+  bf16 + bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
+  bf16 - bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
+  bf16 * bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
+  bf16 / bf16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
+
+  __fp16 fp16;
+
+  bf16 + fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
+  fp16 + bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
+  bf16 - fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
+  fp16 - bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
+  bf16 * fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
+  fp16 * bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
+  bf16 / fp16; // scalar-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
+  fp16 / bf16; // scalar-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
+  bf16 = fp16; // scalar-error {{assigning to '__bf16' from incompatible type '__fp16'}}
+  fp16 = bf16; // scalar-error {{assigning to '__fp16' from incompatible type '__bf16'}}
+  bf16 + (b ? fp16 : bf16); // scalar-error {{incompatible operand types ('__fp16' and '__bf16')}}
+  *out = bf16;
+}

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 1 addition & 0 deletions

@@ -1831,6 +1831,7 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
     break;
 
   case Type::HalfTyID:
+  case Type::BFloatTyID:
   case Type::FloatTyID:
   case Type::DoubleTyID:
     AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 28 additions & 27 deletions

@@ -823,8 +823,10 @@ static Optional<unsigned> pickOpcodeForVT(
   case MVT::i64:
     return Opcode_i64;
   case MVT::f16:
+  case MVT::bf16:
     return Opcode_f16;
   case MVT::v2f16:
+  case MVT::v2bf16:
     return Opcode_f16x2;
   case MVT::f32:
     return Opcode_f32;
@@ -835,6 +837,21 @@ static Optional<unsigned> pickOpcodeForVT(
   }
 }
 
+static int getLdStRegType(EVT VT) {
+  if (VT.isFloatingPoint())
+    switch (VT.getSimpleVT().SimpleTy) {
+    case MVT::f16:
+    case MVT::bf16:
+    case MVT::v2f16:
+    case MVT::v2bf16:
+      return NVPTX::PTXLdStInstCode::Untyped;
+    default:
+      return NVPTX::PTXLdStInstCode::Float;
+    }
+  else
+    return NVPTX::PTXLdStInstCode::Unsigned;
+}
+
 bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   SDLoc dl(N);
   MemSDNode *LD = cast<MemSDNode>(N);
@@ -891,19 +908,16 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   // Vector Setting
   unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
   if (SimpleVT.isVector()) {
-    assert(LoadedVT == MVT::v2f16 && "Unexpected vector type");
-    // v2f16 is loaded using ld.b32
+    assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
+           "Unexpected vector type");
+    // v2f16/v2bf16 is loaded using ld.b32
     fromTypeWidth = 32;
   }
 
   if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
     fromType = NVPTX::PTXLdStInstCode::Signed;
-  else if (ScalarVT.isFloatingPoint())
-    // f16 uses .b16 as its storage type.
-    fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
-                                             : NVPTX::PTXLdStInstCode::Float;
   else
-    fromType = NVPTX::PTXLdStInstCode::Unsigned;
+    fromType = getLdStRegType(ScalarVT);
 
   // Create the machine instruction DAG
   SDValue Chain = N->getOperand(0);
@@ -1033,11 +1047,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
       N->getOperand(N->getNumOperands() - 1))->getZExtValue();
   if (ExtensionType == ISD::SEXTLOAD)
     FromType = NVPTX::PTXLdStInstCode::Signed;
-  else if (ScalarVT.isFloatingPoint())
-    FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
-                                             : NVPTX::PTXLdStInstCode::Float;
   else
-    FromType = NVPTX::PTXLdStInstCode::Unsigned;
+    FromType = getLdStRegType(ScalarVT);
 
   unsigned VecType;
 
@@ -1057,7 +1068,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     // v8f16 is a special case. PTX doesn't have ld.v8.f16
     // instruction. Instead, we split the vector into v2f16 chunks and
    // load them with ld.v4.b32.
-    if (EltVT == MVT::v2f16) {
+    if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
      assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
      EltVT = MVT::i32;
      FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1745,18 +1756,13 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   MVT ScalarVT = SimpleVT.getScalarType();
   unsigned toTypeWidth = ScalarVT.getSizeInBits();
   if (SimpleVT.isVector()) {
-    assert(StoreVT == MVT::v2f16 && "Unexpected vector type");
+    assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
+           "Unexpected vector type");
     // v2f16 is stored using st.b32
     toTypeWidth = 32;
   }
 
-  unsigned int toType;
-  if (ScalarVT.isFloatingPoint())
-    // f16 uses .b16 as its storage type.
-    toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
-                                           : NVPTX::PTXLdStInstCode::Float;
-  else
-    toType = NVPTX::PTXLdStInstCode::Unsigned;
+  unsigned int toType = getLdStRegType(ScalarVT);
 
   // Create the machine instruction DAG
   SDValue Chain = ST->getChain();
@@ -1896,12 +1902,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   assert(StoreVT.isSimple() && "Store value is not simple");
   MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
   unsigned ToTypeWidth = ScalarVT.getSizeInBits();
-  unsigned ToType;
-  if (ScalarVT.isFloatingPoint())
-    ToType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
-                                           : NVPTX::PTXLdStInstCode::Float;
-  else
-    ToType = NVPTX::PTXLdStInstCode::Unsigned;
+  unsigned ToType = getLdStRegType(ScalarVT);
 
   SmallVector<SDValue, 12> StOps;
   SDValue N2;
@@ -1929,7 +1930,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
    // v8f16 is a special case. PTX doesn't have st.v8.f16
    // instruction. Instead, we split the vector into v2f16 chunks and
    // store them with st.v4.b32.
-    if (EltVT == MVT::v2f16) {
+    if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
      assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
      EltVT = MVT::i32;
      ToType = NVPTX::PTXLdStInstCode::Untyped;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 17 additions & 6 deletions

@@ -133,6 +133,9 @@ static bool IsPTXVectorType(MVT VT) {
   case MVT::v2f16:
   case MVT::v4f16:
   case MVT::v8f16: // <4 x f16x2>
+  case MVT::v2bf16:
+  case MVT::v4bf16:
+  case MVT::v8bf16: // <4 x bf16x2>
   case MVT::v2f32:
   case MVT::v4f32:
   case MVT::v2f64:
@@ -190,8 +193,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
     // Vectors with an even number of f16 elements will be passed to
     // us as an array of v2f16 elements. We must match this so we
     // stay in sync with Ins/Outs.
-    if (EltVT == MVT::f16 && NumElts % 2 == 0) {
-      EltVT = MVT::v2f16;
+    if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+      EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
       NumElts /= 2;
     }
     for (unsigned j = 0; j != NumElts; ++j) {
@@ -400,6 +403,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
   addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
   addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
+  addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass);
+  addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
@@ -495,6 +500,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
   setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+  setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
   // TRAP can be lowered to PTX trap
   setOperationAction(ISD::TRAP, MVT::Other, Legal);
@@ -2334,14 +2340,17 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   case MVT::v2i32:
   case MVT::v2i64:
   case MVT::v2f16:
+  case MVT::v2bf16:
   case MVT::v2f32:
   case MVT::v2f64:
   case MVT::v4i8:
   case MVT::v4i16:
   case MVT::v4i32:
   case MVT::v4f16:
+  case MVT::v4bf16:
   case MVT::v4f32:
   case MVT::v8f16: // <4 x f16x2>
+  case MVT::v8bf16: // <4 x bf16x2>
     // This is a "native" vector type
     break;
   }
@@ -2386,7 +2395,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
     // v8f16 is a special case. PTX doesn't have st.v8.f16
     // instruction. Instead, we split the vector into v2f16 chunks and
     // store them with st.v4.b32.
-    assert(EltVT == MVT::f16 && "Wrong type for the vector.");
+    assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+           "Wrong type for the vector.");
     Opcode = NVPTXISD::StoreV4;
     StoreF16x2 = true;
     break;
@@ -4987,11 +4997,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
     // v8f16 is a special case. PTX doesn't have ld.v8.f16
     // instruction. Instead, we split the vector into v2f16 chunks and
    // load them with ld.v4.b32.
-    assert(EltVT == MVT::f16 && "Unsupported v8 vector type.");
+    assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+           "Unsupported v8 vector type.");
     LoadF16x2 = true;
     Opcode = NVPTXISD::LoadV4;
-    EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16,
-                     MVT::Other};
+    EVT VVT = (EltVT == MVT::f16) ? MVT::v2f16 : MVT::v2bf16;
+    EVT ListVTs[] = {VVT, VVT, VVT, VVT, MVT::Other};
     LdResVTs = DAG.getVTList(ListVTs);
     break;
   }
