Skip to content

Commit addd0bb

Browse files
mallick-qcMuntasir Mallickkaushik-quicinc
authored andcommitted
[Hexagon] Enable soft bf16 in hexagon
This patch adds: 1. Support for recognizing the bf16 type in the frontend, plus isel/ABI support for scalar bf16 programs. Limitations: fp_to_bf16 is generated with a TableGen pattern instead of being lowered via expansion. This is because we do not yet support the fcanonicalize instruction, which would prevent an SNaN from being converted to an infinity by truncation. 2. Vector codegen support for bf16. Patch By: Fateme Hosseini Co-authored-by: Muntasir Mallick <[email protected]> Co-authored-by: Kaushik Kulkarni <[email protected]> Change-Id: I767145458dafcaf7691eb9ab4e03d33e5fd03a6a
1 parent 965b338 commit addd0bb

File tree

11 files changed

+900
-484
lines changed

11 files changed

+900
-484
lines changed

clang/lib/Basic/Targets/Hexagon.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,14 @@ bool HexagonTargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
155155
HasFastHalfType = true;
156156
HasFloat16 = true;
157157
}
158+
if (CPU.compare("hexagonv81") >= 0)
159+
HasBFloat16 = true;
160+
158161
return true;
159162
}
160163

164+
// Reports whether the bf16 type is available for the selected CPU.
// HasBFloat16 is set in handleTargetFeatures when the CPU name
// compares >= "hexagonv81".
bool HexagonTargetInfo::hasBFloat16Type() const {
  return HasBFloat16;
}
165+
161166
const char *const HexagonTargetInfo::GCCRegNames[] = {
162167
// Scalar registers:
163168
"r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r11",

clang/lib/Basic/Targets/Hexagon.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
6464
// for modeling predicate registers in HVX, and the bool -> byte
6565
// correspondence matches the HVX architecture.
6666
BoolWidth = BoolAlign = 8;
67+
BFloat16Width = BFloat16Align = 16;
68+
BFloat16Format = &llvm::APFloat::BFloat();
6769
}
6870

6971
llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override;
@@ -95,6 +97,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
9597

9698
bool hasFeature(StringRef Feature) const override;
9799

100+
bool hasBFloat16Type() const override;
101+
98102
bool
99103
initFeatureMap(llvm::StringMap<bool> &Features, DiagnosticsEngine &Diags,
100104
StringRef CPU,

llvm/lib/Target/Hexagon/HexagonCallingConv.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def CC_HexagonStack: CallingConv<[
2525
def CC_Hexagon_Legacy: CallingConv<[
2626
CCIfType<[i1,i8,i16],
2727
CCPromoteToType<i32>>,
28+
CCIfType<[bf16],
29+
CCBitConvertToType<i32>>,
2830
CCIfType<[f32],
2931
CCBitConvertToType<i32>>,
3032
CCIfType<[f64],
@@ -55,6 +57,8 @@ def CC_Hexagon_Legacy: CallingConv<[
5557
def CC_Hexagon: CallingConv<[
5658
CCIfType<[i1,i8,i16],
5759
CCPromoteToType<i32>>,
60+
CCIfType<[bf16],
61+
CCBitConvertToType<i32>>,
5862
CCIfType<[f32],
5963
CCBitConvertToType<i32>>,
6064
CCIfType<[f64],
@@ -88,6 +92,8 @@ def CC_Hexagon: CallingConv<[
8892
def RetCC_Hexagon: CallingConv<[
8993
CCIfType<[i1,i8,i16],
9094
CCPromoteToType<i32>>,
95+
CCIfType<[bf16],
96+
CCBitConvertToType<i32>>,
9197
CCIfType<[f32],
9298
CCBitConvertToType<i32>>,
9399
CCIfType<[f64],
@@ -149,16 +155,16 @@ def CC_Hexagon_HVX: CallingConv<[
149155
CCIfType<[v128i1], CCPromoteToType<v128i8>>>,
150156

151157
CCIfHvx128<
152-
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
158+
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
153159
CCAssignToReg<[V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15]>>>,
154160
CCIfHvx128<
155-
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
161+
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
156162
CCAssignToReg<[W0,W1,W2,W3,W4,W5,W6,W7]>>>,
157163
CCIfHvx128<
158-
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
164+
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
159165
CCAssignToStack<128,128>>>,
160166
CCIfHvx128<
161-
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
167+
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
162168
CCAssignToStack<256,128>>>,
163169

164170
CCDelegateTo<CC_Hexagon>
@@ -175,10 +181,10 @@ def RetCC_Hexagon_HVX: CallingConv<[
175181

176182
// HVX 128-byte mode
177183
CCIfHvx128<
178-
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
184+
CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
179185
CCAssignToReg<[V0]>>>,
180186
CCIfHvx128<
181-
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
187+
CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
182188
CCAssignToReg<[W0]>>>,
183189

184190
CCDelegateTo<RetCC_Hexagon>

llvm/lib/Target/Hexagon/HexagonISelLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,6 +1677,8 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
16771677
}
16781678
// Turn FP truncstore into trunc + store.
16791679
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
1680+
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
1681+
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
16801682
// Turn FP extload into load/fpextend.
16811683
for (MVT VT : MVT::fp_valuetypes())
16821684
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
@@ -1872,9 +1874,15 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
18721874
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
18731875
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
18741876
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
1877+
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
1878+
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand);
1879+
setOperationAction(ISD::FP_TO_BF16, MVT::f64, Expand);
18751880

18761881
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
18771882
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
1883+
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
1884+
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
1885+
18781886
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
18791887
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
18801888

llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ HexagonTargetLowering::initializeHVXLowering() {
8888
addRegisterClass(MVT::v64f32, &Hexagon::HvxWRRegClass);
8989
addRegisterClass(MVT::v128f16, &Hexagon::HvxWRRegClass);
9090
}
91+
if (Subtarget.useHVXV81Ops()) {
92+
addRegisterClass(MVT::v64bf16, &Hexagon::HvxVRRegClass);
93+
addRegisterClass(MVT::v128bf16, &Hexagon::HvxWRRegClass);
94+
}
9195
}
9296

9397
// Set up operation actions.
@@ -162,6 +166,30 @@ HexagonTargetLowering::initializeHVXLowering() {
162166
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64f32, ByteW);
163167
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v32f32, ByteV);
164168

169+
if (Subtarget.useHVXV81Ops()) {
170+
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v128bf16, ByteW);
171+
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64bf16, ByteV);
172+
setPromoteTo(ISD::SETCC, MVT::v64bf16, MVT::v64f32);
173+
setPromoteTo(ISD::FADD, MVT::v64bf16, MVT::v64f32);
174+
setPromoteTo(ISD::FSUB, MVT::v64bf16, MVT::v64f32);
175+
setPromoteTo(ISD::FMUL, MVT::v64bf16, MVT::v64f32);
176+
setPromoteTo(ISD::FMINNUM, MVT::v64bf16, MVT::v64f32);
177+
setPromoteTo(ISD::FMAXNUM, MVT::v64bf16, MVT::v64f32);
178+
179+
setOperationAction(ISD::SPLAT_VECTOR, MVT::v64bf16, Legal);
180+
setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64bf16, Custom);
181+
setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v64bf16, Custom);
182+
183+
setOperationAction(ISD::MLOAD, MVT::v64bf16, Custom);
184+
setOperationAction(ISD::MSTORE, MVT::v64bf16, Custom);
185+
setOperationAction(ISD::BUILD_VECTOR, MVT::v64bf16, Custom);
186+
setOperationAction(ISD::CONCAT_VECTORS, MVT::v64bf16, Custom);
187+
188+
setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom);
189+
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
190+
setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
191+
}
192+
165193
for (MVT P : FloatW) {
166194
setOperationAction(ISD::LOAD, P, Custom);
167195
setOperationAction(ISD::STORE, P, Custom);
@@ -1667,14 +1695,15 @@ HexagonTargetLowering::LowerHvxBuildVector(SDValue Op, SelectionDAG &DAG)
16671695
// In case of an MVT::f16 BUILD_VECTOR, since MVT::f16 is
16681696
// not a legal type, just bitcast the node to use i16 types and
16691697
// bitcast the result back; MVT::bf16 takes the same path.
1670-
if (VecTy.getVectorElementType() == MVT::f16) {
1671-
SmallVector<SDValue,64> NewOps;
1698+
if (VecTy.getVectorElementType() == MVT::f16 ||
1699+
VecTy.getVectorElementType() == MVT::bf16) {
1700+
SmallVector<SDValue, 64> NewOps;
16721701
for (unsigned i = 0; i != Size; i++)
16731702
NewOps.push_back(DAG.getBitcast(MVT::i16, Ops[i]));
16741703

1675-
SDValue T0 = DAG.getNode(ISD::BUILD_VECTOR, dl,
1676-
tyVector(VecTy, MVT::i16), NewOps);
1677-
return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
1704+
SDValue T0 =
1705+
DAG.getNode(ISD::BUILD_VECTOR, dl, tyVector(VecTy, MVT::i16), NewOps);
1706+
return DAG.getBitcast(tyVector(VecTy, VecTy.getVectorElementType()), T0);
16781707
}
16791708

16801709
// First, split the BUILD_VECTOR for vector pairs. We could generate
@@ -1698,7 +1727,7 @@ HexagonTargetLowering::LowerHvxSplatVector(SDValue Op, SelectionDAG &DAG)
16981727
MVT VecTy = ty(Op);
16991728
MVT ArgTy = ty(Op.getOperand(0));
17001729

1701-
if (ArgTy == MVT::f16) {
1730+
if (ArgTy == MVT::f16 || ArgTy == MVT::bf16) {
17021731
MVT SplatTy = MVT::getVectorVT(MVT::i16, VecTy.getVectorNumElements());
17031732
SDValue ToInt16 = DAG.getBitcast(MVT::i16, Op.getOperand(0));
17041733
SDValue ToInt32 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, ToInt16);
@@ -1831,12 +1860,12 @@ HexagonTargetLowering::LowerHvxInsertElement(SDValue Op, SelectionDAG &DAG)
18311860
if (ElemTy == MVT::i1)
18321861
return insertHvxElementPred(VecV, IdxV, ValV, dl, DAG);
18331862

1834-
if (ElemTy == MVT::f16) {
1863+
if (ElemTy == MVT::f16 || ElemTy == MVT::bf16) {
18351864
SDValue T0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl,
18361865
tyVector(VecTy, MVT::i16),
18371866
DAG.getBitcast(tyVector(VecTy, MVT::i16), VecV),
18381867
DAG.getBitcast(MVT::i16, ValV), IdxV);
1839-
return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
1868+
return DAG.getBitcast(tyVector(VecTy, ElemTy), T0);
18401869
}
18411870

18421871
return insertHvxElementReg(VecV, IdxV, ValV, dl, DAG);
@@ -2334,6 +2363,20 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
23342363
MVT VecTy = ty(Op);
23352364
MVT ArgTy = ty(Op.getOperand(0));
23362365
const SDLoc &dl(Op);
2366+
2367+
if (ArgTy == MVT::v64bf16) {
2368+
MVT HalfTy = typeSplit(VecTy).first;
2369+
SDValue BF16Vec = Op.getOperand(0);
2370+
SDValue Zeroes = getInstr(Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
2371+
// Interleave zero vector with the bf16 vector, with zeroes in the lower half
2372+
// of each 32 bit lane, effectively extending the bf16 values to fp32 values.
2373+
SDValue ShuffVec = getInstr(Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
2374+
VectorPair VecPair = opSplit(ShuffVec, dl, DAG);
2375+
SDValue Result = getInstr(Hexagon::V6_vshuffvdd, dl, VecTy,
2376+
{VecPair.second, VecPair.first, DAG.getSignedConstant(-4, dl, MVT::i32)}, DAG);
2377+
return Result;
2378+
}
2379+
23372380
assert(VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);
23382381

23392382
SDValue F16Vec = Op.getOperand(0);

llvm/lib/Target/Hexagon/HexagonPatterns.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def Fptoui: pf1<fp_to_uint>;
391391
def Sitofp: pf1<sint_to_fp>;
392392
def Uitofp: pf1<uint_to_fp>;
393393

394-
395394
// --(1) Immediate -------------------------------------------------------
396395
//
397396

@@ -474,6 +473,18 @@ def: OpR_R_pat<F2_conv_df2uw_chop, pf1<fp_to_uint>, i32, F64>;
474473
def: OpR_R_pat<F2_conv_sf2ud_chop, pf1<fp_to_uint>, i64, F32>;
475474
def: OpR_R_pat<F2_conv_df2ud_chop, pf1<fp_to_uint>, i64, F64>;
476475

476+
// Selects fp_to_bf16 on a scalar f32 with an explicit instruction sequence.
// Shape of the selected sequence, outermost mux first:
//   1. If F2_sfclass(v, 0x10) holds, the result is the constant 0x7fff.
//      NOTE(review): 0x10 is presumably the NaN class mask -- confirm
//      against the Hexagon ISA manual.
//   2. Else if (v & 0x1FFFF) == 0x08000, the result is A2_asrh(v) & 65535
//      with no adjustment. This looks like the round-to-nearest-even tie
//      case where the lowest kept bit is already 0 -- confirm.
//   3. Otherwise (v & 0x8000) is added to v before A2_asrh and the 65535
//      mask, which appears to round up whenever bit 15 of v is set.
def: Pat<(i32 (fp_to_bf16 F32:$v)),
477+
(C2_mux (F2_sfclass F32:$v, 0x10), (A2_tfrsi(i32 0x7fff)),
478+
(C2_mux
479+
(C2_cmpeq
480+
(A2_and F32:$v, (A2_tfrsi (i32 0x1FFFF))),
481+
(A2_tfrsi (i32 0x08000))),
482+
(A2_and (A2_asrh F32:$v), (A2_tfrsi (i32 65535))),
483+
(A2_and
484+
(A2_asrh
485+
(A2_add F32:$v, (A2_and F32:$v, (A2_tfrsi (i32 0x8000))))),
486+
(A2_tfrsi (i32 65535))))
487+
)>;
477488
// Bitcast is different than [fp|sint|uint]_to_[sint|uint|fp].
478489
def: Pat<(i32 (bitconvert F32:$v)), (I32:$v)>;
479490
def: Pat<(f32 (bitconvert I32:$v)), (F32:$v)>;

0 commit comments

Comments
 (0)