Skip to content

Commit fef9547

Browse files
committed
[LoongArch][BF16] Add support for the __bf16 type
The LoongArch psABI recently added support for the __bf16 type, so we can now enable this type in clang. bf16 operations are currently supported automatically by promoting to float. This patch adds bf16 support by ensuring that load-extension / truncate-store operations are properly expanded, and it implements bf16 truncate/extend on hard-FP targets. The extend operation is implemented with a shift, just as in the standard legalization. This requires custom lowering of the truncate libcall on hard-float ABIs (the normal libcall code path is used on soft-float ABIs).
1 parent 106c897 commit fef9547

File tree

8 files changed

+1822
-4
lines changed

8 files changed

+1822
-4
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,7 @@ to ``float``; see below for more information on this emulation.
10091009
* 64-bit ARM (AArch64)
10101010
* RISC-V
10111011
* X86 (when SSE2 is available)
1012+
* LoongArch
10121013

10131014
(For X86, SSE2 is available on 64-bit and all recent 32-bit processors.)
10141015

clang/lib/Basic/Targets/LoongArch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class LLVM_LIBRARY_VISIBILITY LoongArchTargetInfo : public TargetInfo {
4949
HasFeatureLD_SEQ_SA = false;
5050
HasFeatureDiv32 = false;
5151
HasFeatureSCQ = false;
52+
BFloat16Width = 16;
53+
BFloat16Align = 16;
54+
BFloat16Format = &llvm::APFloat::BFloat();
5255
LongDoubleWidth = 128;
5356
LongDoubleAlign = 128;
5457
LongDoubleFormat = &llvm::APFloat::IEEEquad();
@@ -99,6 +102,8 @@ class LLVM_LIBRARY_VISIBILITY LoongArchTargetInfo : public TargetInfo {
99102

100103
bool hasBitIntType() const override { return true; }
101104

105+
bool hasBFloat16Type() const override { return true; }
106+
102107
bool useFP16ConversionIntrinsics() const override { return false; }
103108

104109
bool handleTargetFeatures(std::vector<std::string> &Features,

clang/test/CodeGen/LoongArch/bfloat-abi.c

Lines changed: 532 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 2
2+
// RUN: %clang_cc1 -triple loongarch64 -emit-llvm -o - %s | FileCheck %s
3+
// RUN: %clang_cc1 -triple loongarch32 -emit-llvm -o - %s | FileCheck %s
4+
5+
// CHECK-LABEL: define dso_local void @_Z3fooDF16b
6+
// CHECK-SAME: (bfloat noundef [[B:%.*]]) #[[ATTR0:[0-9]+]] {
7+
// CHECK-NEXT: entry:
8+
// CHECK-NEXT: [[B_ADDR:%.*]] = alloca bfloat, align 2
9+
// CHECK-NEXT: store bfloat [[B]], ptr [[B_ADDR]], align 2
10+
// CHECK-NEXT: ret void
11+
//
12+
void foo(__bf16 b) {}

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
182182
if (Subtarget.hasBasicF()) {
183183
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
184184
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
185+
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
186+
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
185187
setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
186188

187189
setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
@@ -203,6 +205,9 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
203205
Subtarget.isSoftFPABI() ? LibCall : Custom);
204206
setOperationAction(ISD::FP_TO_FP16, MVT::f32,
205207
Subtarget.isSoftFPABI() ? LibCall : Custom);
208+
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
209+
setOperationAction(ISD::FP_TO_BF16, MVT::f32,
210+
Subtarget.isSoftFPABI() ? LibCall : Custom);
206211

207212
if (Subtarget.is64Bit())
208213
setOperationAction(ISD::FRINT, MVT::f32, Legal);
@@ -221,6 +226,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
221226
if (Subtarget.hasBasicD()) {
222227
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
223228
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
229+
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
230+
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
224231
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
225232
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
226233
setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
@@ -243,6 +250,9 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
243250
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
244251
setOperationAction(ISD::FP_TO_FP16, MVT::f64,
245252
Subtarget.isSoftFPABI() ? LibCall : Custom);
253+
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
254+
setOperationAction(ISD::FP_TO_BF16, MVT::f64,
255+
Subtarget.isSoftFPABI() ? LibCall : Custom);
246256

247257
if (Subtarget.is64Bit())
248258
setOperationAction(ISD::FRINT, MVT::f64, Legal);
@@ -497,6 +507,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
497507
return lowerFP_TO_FP16(Op, DAG);
498508
case ISD::FP16_TO_FP:
499509
return lowerFP16_TO_FP(Op, DAG);
510+
case ISD::FP_TO_BF16:
511+
return lowerFP_TO_BF16(Op, DAG);
512+
case ISD::BF16_TO_FP:
513+
return lowerBF16_TO_FP(Op, DAG);
500514
}
501515
return SDValue();
502516
}
@@ -2283,6 +2297,36 @@ SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
22832297
return Res;
22842298
}
22852299

2300+
SDValue LoongArchTargetLowering::lowerFP_TO_BF16(SDValue Op,
2301+
SelectionDAG &DAG) const {
2302+
assert(Subtarget.hasBasicF() && "Unexpected custom legalization");
2303+
SDLoc DL(Op);
2304+
MakeLibCallOptions CallOptions;
2305+
RTLIB::Libcall LC =
2306+
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
2307+
SDValue Res =
2308+
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
2309+
if (Subtarget.is64Bit())
2310+
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
2311+
return DAG.getBitcast(MVT::i32, Res);
2312+
}
2313+
2314+
SDValue LoongArchTargetLowering::lowerBF16_TO_FP(SDValue Op,
2315+
SelectionDAG &DAG) const {
2316+
assert(Subtarget.hasBasicF() && "Unexpected custom legalization");
2317+
MVT VT = Op.getSimpleValueType();
2318+
SDLoc DL(Op);
2319+
Op = DAG.getNode(
2320+
ISD::SHL, DL, Op.getOperand(0).getValueType(), Op.getOperand(0),
2321+
DAG.getShiftAmountConstant(16, Op.getOperand(0).getValueType(), DL));
2322+
SDValue Res = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
2323+
DL, MVT::f32, Op)
2324+
: DAG.getBitcast(MVT::f32, Op);
2325+
if (VT != MVT::f32)
2326+
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
2327+
return Res;
2328+
}
2329+
22862330
static bool isConstantOrUndef(const SDValue Op) {
22872331
if (Op->isUndef())
22882332
return true;
@@ -7714,8 +7758,9 @@ bool LoongArchTargetLowering::splitValueIntoRegisterParts(
77147758
bool IsABIRegCopy = CC.has_value();
77157759
EVT ValueVT = Val.getValueType();
77167760

7717-
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
7718-
// Cast the f16 to i16, extend to i32, pad with ones to make a float
7761+
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
7762+
PartVT == MVT::f32) {
7763+
// Cast the [b]f16 to i16, extend to i32, pad with ones to make a float
77197764
// nan, and cast to f32.
77207765
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
77217766
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
@@ -7734,10 +7779,11 @@ SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
77347779
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
77357780
bool IsABIRegCopy = CC.has_value();
77367781

7737-
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
7782+
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
7783+
PartVT == MVT::f32) {
77387784
SDValue Val = Parts[0];
77397785

7740-
// Cast the f32 to i32, truncate to i16, and cast back to f16.
7786+
// Cast the f32 to i32, truncate to i16, and cast back to [b]f16.
77417787
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
77427788
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
77437789
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ class LoongArchTargetLowering : public TargetLowering {
363363
SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
364364
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
365365
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
366+
SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
367+
SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
366368

367369
bool isFPImmLegal(const APFloat &Imm, EVT VT,
368370
bool ForCodeSize) const override;
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=loongarch64 -mattr=+d -target-abi=lp64d < %s | FileCheck --check-prefixes=CHECK,LA64 %s
3+
; RUN: llc -mtriple=loongarch32 -mattr=+d -target-abi=ilp32d < %s | FileCheck --check-prefixes=CHECK,LA32 %s
4+
5+
define void @test_load_store(ptr %p, ptr %q) nounwind {
6+
; CHECK-LABEL: test_load_store:
7+
; CHECK: # %bb.0:
8+
; CHECK-NEXT: ld.h $a0, $a0, 0
9+
; CHECK-NEXT: st.h $a0, $a1, 0
10+
; CHECK-NEXT: ret
11+
%a = load bfloat, ptr %p
12+
store bfloat %a, ptr %q
13+
ret void
14+
}
15+
16+
define float @test_fpextend_float(ptr %p) nounwind {
17+
; LA64-LABEL: test_fpextend_float:
18+
; LA64: # %bb.0:
19+
; LA64-NEXT: ld.hu $a0, $a0, 0
20+
; LA64-NEXT: slli.d $a0, $a0, 16
21+
; LA64-NEXT: movgr2fr.w $fa0, $a0
22+
; LA64-NEXT: ret
23+
;
24+
; LA32-LABEL: test_fpextend_float:
25+
; LA32: # %bb.0:
26+
; LA32-NEXT: ld.hu $a0, $a0, 0
27+
; LA32-NEXT: slli.w $a0, $a0, 16
28+
; LA32-NEXT: movgr2fr.w $fa0, $a0
29+
; LA32-NEXT: ret
30+
%a = load bfloat, ptr %p
31+
%r = fpext bfloat %a to float
32+
ret float %r
33+
}
34+
35+
define double @test_fpextend_double(ptr %p) nounwind {
36+
; LA64-LABEL: test_fpextend_double:
37+
; LA64: # %bb.0:
38+
; LA64-NEXT: ld.hu $a0, $a0, 0
39+
; LA64-NEXT: slli.d $a0, $a0, 16
40+
; LA64-NEXT: movgr2fr.w $fa0, $a0
41+
; LA64-NEXT: fcvt.d.s $fa0, $fa0
42+
; LA64-NEXT: ret
43+
;
44+
; LA32-LABEL: test_fpextend_double:
45+
; LA32: # %bb.0:
46+
; LA32-NEXT: ld.hu $a0, $a0, 0
47+
; LA32-NEXT: slli.w $a0, $a0, 16
48+
; LA32-NEXT: movgr2fr.w $fa0, $a0
49+
; LA32-NEXT: fcvt.d.s $fa0, $fa0
50+
; LA32-NEXT: ret
51+
%a = load bfloat, ptr %p
52+
%r = fpext bfloat %a to double
53+
ret double %r
54+
}
55+
56+
define void @test_fptrunc_float(float %f, ptr %p) nounwind {
57+
; LA64-LABEL: test_fptrunc_float:
58+
; LA64: # %bb.0:
59+
; LA64-NEXT: addi.d $sp, $sp, -16
60+
; LA64-NEXT: st.d $ra, $sp, 8 # 8-byte Folded Spill
61+
; LA64-NEXT: st.d $fp, $sp, 0 # 8-byte Folded Spill
62+
; LA64-NEXT: move $fp, $a0
63+
; LA64-NEXT: pcaddu18i $ra, %call36(__truncsfbf2)
64+
; LA64-NEXT: jirl $ra, $ra, 0
65+
; LA64-NEXT: movfr2gr.s $a0, $fa0
66+
; LA64-NEXT: st.h $a0, $fp, 0
67+
; LA64-NEXT: ld.d $fp, $sp, 0 # 8-byte Folded Reload
68+
; LA64-NEXT: ld.d $ra, $sp, 8 # 8-byte Folded Reload
69+
; LA64-NEXT: addi.d $sp, $sp, 16
70+
; LA64-NEXT: ret
71+
;
72+
; LA32-LABEL: test_fptrunc_float:
73+
; LA32: # %bb.0:
74+
; LA32-NEXT: addi.w $sp, $sp, -16
75+
; LA32-NEXT: st.w $ra, $sp, 12 # 4-byte Folded Spill
76+
; LA32-NEXT: st.w $fp, $sp, 8 # 4-byte Folded Spill
77+
; LA32-NEXT: move $fp, $a0
78+
; LA32-NEXT: bl __truncsfbf2
79+
; LA32-NEXT: movfr2gr.s $a0, $fa0
80+
; LA32-NEXT: st.h $a0, $fp, 0
81+
; LA32-NEXT: ld.w $fp, $sp, 8 # 4-byte Folded Reload
82+
; LA32-NEXT: ld.w $ra, $sp, 12 # 4-byte Folded Reload
83+
; LA32-NEXT: addi.w $sp, $sp, 16
84+
; LA32-NEXT: ret
85+
%a = fptrunc float %f to bfloat
86+
store bfloat %a, ptr %p
87+
ret void
88+
}
89+
90+
define void @test_fptrunc_double(double %d, ptr %p) nounwind {
91+
; LA64-LABEL: test_fptrunc_double:
92+
; LA64: # %bb.0:
93+
; LA64-NEXT: addi.d $sp, $sp, -16
94+
; LA64-NEXT: st.d $ra, $sp, 8 # 8-byte Folded Spill
95+
; LA64-NEXT: st.d $fp, $sp, 0 # 8-byte Folded Spill
96+
; LA64-NEXT: move $fp, $a0
97+
; LA64-NEXT: pcaddu18i $ra, %call36(__truncdfbf2)
98+
; LA64-NEXT: jirl $ra, $ra, 0
99+
; LA64-NEXT: movfr2gr.s $a0, $fa0
100+
; LA64-NEXT: st.h $a0, $fp, 0
101+
; LA64-NEXT: ld.d $fp, $sp, 0 # 8-byte Folded Reload
102+
; LA64-NEXT: ld.d $ra, $sp, 8 # 8-byte Folded Reload
103+
; LA64-NEXT: addi.d $sp, $sp, 16
104+
; LA64-NEXT: ret
105+
;
106+
; LA32-LABEL: test_fptrunc_double:
107+
; LA32: # %bb.0:
108+
; LA32-NEXT: addi.w $sp, $sp, -16
109+
; LA32-NEXT: st.w $ra, $sp, 12 # 4-byte Folded Spill
110+
; LA32-NEXT: st.w $fp, $sp, 8 # 4-byte Folded Spill
111+
; LA32-NEXT: move $fp, $a0
112+
; LA32-NEXT: bl __truncdfbf2
113+
; LA32-NEXT: movfr2gr.s $a0, $fa0
114+
; LA32-NEXT: st.h $a0, $fp, 0
115+
; LA32-NEXT: ld.w $fp, $sp, 8 # 4-byte Folded Reload
116+
; LA32-NEXT: ld.w $ra, $sp, 12 # 4-byte Folded Reload
117+
; LA32-NEXT: addi.w $sp, $sp, 16
118+
; LA32-NEXT: ret
119+
%a = fptrunc double %d to bfloat
120+
store bfloat %a, ptr %p
121+
ret void
122+
}
123+
124+
define void @test_fadd(ptr %p, ptr %q) nounwind {
125+
; LA64-LABEL: test_fadd:
126+
; LA64: # %bb.0:
127+
; LA64-NEXT: addi.d $sp, $sp, -16
128+
; LA64-NEXT: st.d $ra, $sp, 8 # 8-byte Folded Spill
129+
; LA64-NEXT: st.d $fp, $sp, 0 # 8-byte Folded Spill
130+
; LA64-NEXT: ld.hu $a1, $a1, 0
131+
; LA64-NEXT: move $fp, $a0
132+
; LA64-NEXT: ld.hu $a0, $a0, 0
133+
; LA64-NEXT: slli.d $a1, $a1, 16
134+
; LA64-NEXT: movgr2fr.w $fa0, $a1
135+
; LA64-NEXT: slli.d $a0, $a0, 16
136+
; LA64-NEXT: movgr2fr.w $fa1, $a0
137+
; LA64-NEXT: fadd.s $fa0, $fa1, $fa0
138+
; LA64-NEXT: pcaddu18i $ra, %call36(__truncsfbf2)
139+
; LA64-NEXT: jirl $ra, $ra, 0
140+
; LA64-NEXT: movfr2gr.s $a0, $fa0
141+
; LA64-NEXT: st.h $a0, $fp, 0
142+
; LA64-NEXT: ld.d $fp, $sp, 0 # 8-byte Folded Reload
143+
; LA64-NEXT: ld.d $ra, $sp, 8 # 8-byte Folded Reload
144+
; LA64-NEXT: addi.d $sp, $sp, 16
145+
; LA64-NEXT: ret
146+
;
147+
; LA32-LABEL: test_fadd:
148+
; LA32: # %bb.0:
149+
; LA32-NEXT: addi.w $sp, $sp, -16
150+
; LA32-NEXT: st.w $ra, $sp, 12 # 4-byte Folded Spill
151+
; LA32-NEXT: st.w $fp, $sp, 8 # 4-byte Folded Spill
152+
; LA32-NEXT: ld.hu $a1, $a1, 0
153+
; LA32-NEXT: move $fp, $a0
154+
; LA32-NEXT: ld.hu $a0, $a0, 0
155+
; LA32-NEXT: slli.w $a1, $a1, 16
156+
; LA32-NEXT: movgr2fr.w $fa0, $a1
157+
; LA32-NEXT: slli.w $a0, $a0, 16
158+
; LA32-NEXT: movgr2fr.w $fa1, $a0
159+
; LA32-NEXT: fadd.s $fa0, $fa1, $fa0
160+
; LA32-NEXT: bl __truncsfbf2
161+
; LA32-NEXT: movfr2gr.s $a0, $fa0
162+
; LA32-NEXT: st.h $a0, $fp, 0
163+
; LA32-NEXT: ld.w $fp, $sp, 8 # 4-byte Folded Reload
164+
; LA32-NEXT: ld.w $ra, $sp, 12 # 4-byte Folded Reload
165+
; LA32-NEXT: addi.w $sp, $sp, 16
166+
; LA32-NEXT: ret
167+
%a = load bfloat, ptr %p
168+
%b = load bfloat, ptr %q
169+
%r = fadd bfloat %a, %b
170+
store bfloat %r, ptr %p
171+
ret void
172+
}

0 commit comments

Comments
 (0)