Skip to content

Commit 4e3266f

Browse files
authored
[RISCV] Implement load/store support for XAndesBFHCvt (#150350)
We use `lh` to load 2 bytes from memory into a gpr, then mask this gpr with -65536 to emulate nan-boxing behavior, and then the value in gpr is moved to fpr using `fmv.w.x`. To move the value back from fpr to gpr, we use `fmv.x.w` and finally, `sh` is used to store the lower 2 bytes back to memory. If zfh is enabled at the same time, we can just use flh/fsw to load/store bf16 directly.
1 parent b0dea47 commit 4e3266f

File tree

4 files changed

+133
-2
lines changed

4 files changed

+133
-2
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
16181618
}
16191619
}
16201620

1621+
// Customize load and store operation for bf16 if zfh isn't enabled.
1622+
if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
1623+
setOperationAction(ISD::LOAD, MVT::bf16, Custom);
1624+
setOperationAction(ISD::STORE, MVT::bf16, Custom);
1625+
}
1626+
16211627
// Function alignments.
16221628
const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
16231629
setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
72167222
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
72177223
}
72187224

7225+
SDValue
7226+
RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
7227+
SelectionDAG &DAG) const {
7228+
assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
7229+
"Unexpected bfloat16 load lowering");
7230+
7231+
SDLoc DL(Op);
7232+
LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
7233+
EVT MemVT = LD->getMemoryVT();
7234+
SDValue Load = DAG.getExtLoad(
7235+
ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
7236+
LD->getBasePtr(),
7237+
EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
7238+
LD->getMemOperand());
7239+
// Using mask to make bf16 nan-boxing valid when we don't have flh
7240+
// instruction. -65536 would be treat as a small number and thus it can be
7241+
// directly used lui to get the constant.
7242+
SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
7243+
SDValue OrSixteenOne =
7244+
DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
7245+
SDValue ConvertedResult =
7246+
DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
7247+
return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
7248+
}
7249+
7250+
SDValue
7251+
RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
7252+
SelectionDAG &DAG) const {
7253+
assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
7254+
"Unexpected bfloat16 store lowering");
7255+
7256+
StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
7257+
SDLoc DL(Op);
7258+
SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
7259+
Subtarget.getXLenVT(), ST->getValue());
7260+
return DAG.getTruncStore(
7261+
ST->getChain(), DL, FMV, ST->getBasePtr(),
7262+
EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
7263+
ST->getMemOperand());
7264+
}
7265+
72197266
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
72207267
SelectionDAG &DAG) const {
72217268
switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
79147961
return DAG.getMergeValues({Pair, Chain}, DL);
79157962
}
79167963

7964+
if (VT == MVT::bf16)
7965+
return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
7966+
79177967
// Handle normal vector tuple load.
79187968
if (VT.isRISCVVectorTuple()) {
79197969
SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
79988048
{Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
79998049
Store->getMemOperand());
80008050
}
8051+
8052+
if (VT == MVT::bf16)
8053+
return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
8054+
80018055
// Handle normal vector tuple store.
80028056
if (VT.isRISCVVectorTuple()) {
80038057
SDLoc DL(Op);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,9 @@ class RISCVTargetLowering : public TargetLowering {
578578
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
579579
SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
580580

581+
SDValue lowerXAndesBfHCvtBFloat16Load(SDValue Op, SelectionDAG &DAG) const;
582+
SDValue lowerXAndesBfHCvtBFloat16Store(SDValue Op, SelectionDAG &DAG) const;
583+
581584
bool isEligibleForTailCallOptimization(
582585
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
583586
const SmallVector<CCValAssign, 16> &ArgLocs) const;

llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
//===----------------------------------------------------------------------===//
14+
// RISC-V specific DAG Nodes.
15+
//===----------------------------------------------------------------------===//
16+
17+
def SDT_NDS_FMV_BF16_X
18+
: SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, XLenVT>]>;
19+
def SDT_NDS_FMV_X_ANYEXTBF16
20+
: SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisVT<1, bf16>]>;
21+
22+
def riscv_nds_fmv_bf16_x
23+
: SDNode<"RISCVISD::NDS_FMV_BF16_X", SDT_NDS_FMV_BF16_X>;
24+
def riscv_nds_fmv_x_anyextbf16
25+
: SDNode<"RISCVISD::NDS_FMV_X_ANYEXTBF16", SDT_NDS_FMV_X_ANYEXTBF16>;
26+
1327
//===----------------------------------------------------------------------===//
1428
// Operand and SDNode transformation definitions.
1529
//===----------------------------------------------------------------------===//
@@ -773,6 +787,25 @@ def : Pat<(bf16 (fpround FPR32:$rs)),
773787
(NDS_FCVT_BF16_S FPR32:$rs)>;
774788
} // Predicates = [HasVendorXAndesBFHCvt]
775789

790+
let isCodeGenOnly = 1 in {
791+
def NDS_FMV_BF16_X : FPUnaryOp_r<0b1111000, 0b00000, 0b000, FPR16, GPR, "fmv.w.x">,
792+
Sched<[WriteFMovI32ToF32, ReadFMovI32ToF32]>;
793+
def NDS_FMV_X_BF16 : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR16, "fmv.x.w">,
794+
Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>;
795+
}
796+
797+
let Predicates = [HasVendorXAndesBFHCvt] in {
798+
def : Pat<(riscv_nds_fmv_bf16_x GPR:$src), (NDS_FMV_BF16_X GPR:$src)>;
799+
def : Pat<(riscv_nds_fmv_x_anyextbf16 (bf16 FPR16:$src)),
800+
(NDS_FMV_X_BF16 (bf16 FPR16:$src))>;
801+
} // Predicates = [HasVendorXAndesBFHCvt]
802+
803+
// Use flh/fsh to load/store bf16 if zfh is enabled.
804+
let Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt] in {
805+
def : LdPat<load, FLH, bf16>;
806+
def : StPat<store, FSH, FPR16, bf16>;
807+
} // Predicates = [HasStdExtZfh, HasVendorXAndesBFHCvt]
808+
776809
let Predicates = [HasVendorXAndesVBFHCvt] in {
777810
defm PseudoNDS_VFWCVT_S_BF16 : VPseudoVWCVT_S_BF16;
778811
defm PseudoNDS_VFNCVT_BF16_S : VPseudoVNCVT_BF16_S;

llvm/test/CodeGen/RISCV/xandesbfhcvt.ll

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
22
; RUN: llc -mtriple=riscv32 -mattr=+xandesbfhcvt -target-abi ilp32f \
3-
; RUN: -verify-machineinstrs < %s | FileCheck %s
3+
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
4+
; RUN: llc -mtriple=riscv32 -mattr=+zfh,+xandesbfhcvt -target-abi ilp32f \
5+
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
46
; RUN: llc -mtriple=riscv64 -mattr=+xandesbfhcvt -target-abi lp64f \
5-
; RUN: -verify-machineinstrs < %s | FileCheck %s
7+
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,XANDESBFHCVT %s
8+
; RUN: llc -mtriple=riscv64 -mattr=+zfh,+xandesbfhcvt -target-abi lp64f \
9+
; RUN: -verify-machineinstrs < %s | FileCheck --check-prefixes=CHECK,ZFH %s
610

711
define float @fcvt_s_bf16(bfloat %a) nounwind {
812
; CHECK-LABEL: fcvt_s_bf16:
@@ -21,3 +25,40 @@ define bfloat @fcvt_bf16_s(float %a) nounwind {
2125
%1 = fptrunc float %a to bfloat
2226
ret bfloat %1
2327
}
28+
29+
; Check load and store to bf16.
30+
define void @loadstorebf16(ptr %bf, ptr %sf) nounwind {
31+
; XANDESBFHCVT-LABEL: loadstorebf16:
32+
; XANDESBFHCVT: # %bb.0: # %entry
33+
; XANDESBFHCVT-NEXT: lhu a2, 0(a0)
34+
; XANDESBFHCVT-NEXT: lui a3, 1048560
35+
; XANDESBFHCVT-NEXT: or a2, a2, a3
36+
; XANDESBFHCVT-NEXT: fmv.w.x fa5, a2
37+
; XANDESBFHCVT-NEXT: nds.fcvt.s.bf16 fa5, fa5
38+
; XANDESBFHCVT-NEXT: fsw fa5, 0(a1)
39+
; XANDESBFHCVT-NEXT: flw fa5, 0(a1)
40+
; XANDESBFHCVT-NEXT: nds.fcvt.bf16.s fa5, fa5
41+
; XANDESBFHCVT-NEXT: fmv.x.w a1, fa5
42+
; XANDESBFHCVT-NEXT: sh a1, 0(a0)
43+
; XANDESBFHCVT-NEXT: ret
44+
;
45+
; ZFH-LABEL: loadstorebf16:
46+
; ZFH: # %bb.0: # %entry
47+
; ZFH-NEXT: flh fa5, 0(a0)
48+
; ZFH-NEXT: nds.fcvt.s.bf16 fa5, fa5
49+
; ZFH-NEXT: fsw fa5, 0(a1)
50+
; ZFH-NEXT: flw fa5, 0(a1)
51+
; ZFH-NEXT: nds.fcvt.bf16.s fa5, fa5
52+
; ZFH-NEXT: fsh fa5, 0(a0)
53+
; ZFH-NEXT: ret
54+
entry:
55+
%0 = load bfloat, bfloat* %bf, align 2
56+
%1 = fpext bfloat %0 to float
57+
store volatile float %1, float* %sf, align 4
58+
59+
%2 = load float, float* %sf, align 4
60+
%3 = fptrunc float %2 to bfloat
61+
store volatile bfloat %3, bfloat* %bf, align 2
62+
63+
ret void
64+
}

0 commit comments

Comments
 (0)