@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
16181618 }
16191619 }
16201620
1621+ // Customize load and store operation for bf16 if zfh isn't enabled.
1622+ if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
1623+ setOperationAction(ISD::LOAD, MVT::bf16, Custom);
1624+ setOperationAction(ISD::STORE, MVT::bf16, Custom);
1625+ }
1626+
16211627 // Function alignments.
16221628 const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
16231629 setMinFunctionAlignment(FunctionAlignment);
@@ -7216,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
72167222 return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
72177223}
72187224
7225+ SDValue
7226+ RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
7227+ SelectionDAG &DAG) const {
7228+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
7229+ "Unexpected bfloat16 load lowering");
7230+
7231+ SDLoc DL(Op);
7232+ LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
7233+ EVT MemVT = LD->getMemoryVT();
7234+ SDValue Load = DAG.getExtLoad(
7235+ ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
7236+ LD->getBasePtr(),
7237+ EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
7238+ LD->getMemOperand());
7239+ // Using mask to make bf16 nan-boxing valid when we don't have flh
7240+ // instruction. -65536 would be treat as a small number and thus it can be
7241+ // directly used lui to get the constant.
7242+ SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
7243+ SDValue OrSixteenOne =
7244+ DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
7245+ SDValue ConvertedResult =
7246+ DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
7247+ return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
7248+ }
7249+
7250+ SDValue
7251+ RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
7252+ SelectionDAG &DAG) const {
7253+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
7254+ "Unexpected bfloat16 store lowering");
7255+
7256+ StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
7257+ SDLoc DL(Op);
7258+ SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
7259+ Subtarget.getXLenVT(), ST->getValue());
7260+ return DAG.getTruncStore(
7261+ ST->getChain(), DL, FMV, ST->getBasePtr(),
7262+ EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
7263+ ST->getMemOperand());
7264+ }
7265+
72197266SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
72207267 SelectionDAG &DAG) const {
72217268 switch (Op.getOpcode()) {
@@ -7914,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
79147961 return DAG.getMergeValues({Pair, Chain}, DL);
79157962 }
79167963
7964+ if (VT == MVT::bf16)
7965+ return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
7966+
79177967 // Handle normal vector tuple load.
79187968 if (VT.isRISCVVectorTuple()) {
79197969 SDLoc DL(Op);
@@ -7998,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
79988048 {Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
79998049 Store->getMemOperand());
80008050 }
8051+
8052+ if (VT == MVT::bf16)
8053+ return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
8054+
80018055 // Handle normal vector tuple store.
80028056 if (VT.isRISCVVectorTuple()) {
80038057 SDLoc DL(Op);
0 commit comments