@@ -804,6 +804,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
804804 setAllExpand(MVT::bf16);
805805 if (!Subtarget->hasFullFP16())
806806 setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
807+ } else {
808+ setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
809+ setOperationAction(ISD::FP_TO_BF16, MVT::f32, Custom);
807810 }
808811
809812 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
@@ -6301,10 +6304,13 @@ SDValue ARMTargetLowering::ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
63016304 DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), MVT::i32, Op));
63026305
63036306 if ((DstVT == MVT::i16 || DstVT == MVT::i32) &&
6304- (SrcVT == MVT::f16 || SrcVT == MVT::bf16))
6307+ (SrcVT == MVT::f16 || SrcVT == MVT::bf16)) {
6308+ if (Subtarget->hasFullFP16() && !Subtarget->hasBF16())
6309+ Op = DAG.getBitcast(MVT::f16, Op);
63056310 return DAG.getNode(
63066311 ISD::TRUNCATE, SDLoc(N), DstVT,
63076312 MoveFromHPR(SDLoc(N), DAG, MVT::i32, SrcVT.getSimpleVT(), Op));
6313+ }
63086314
63096315 if (!(SrcVT == MVT::i64 || DstVT == MVT::i64))
63106316 return SDValue();
@@ -10588,6 +10594,17 @@ SDValue ARMTargetLowering::LowerSPONENTRY(SDValue Op, SelectionDAG &DAG) const {
1058810594 return DAG.getFrameIndex(FI, VT);
1058910595}
1059010596
10597+ SDValue ARMTargetLowering::LowerFP_TO_BF16(SDValue Op,
10598+ SelectionDAG &DAG) const {
10599+ SDLoc DL(Op);
10600+ MakeLibCallOptions CallOptions;
10601+ MVT SVT = Op.getOperand(0).getSimpleValueType();
10602+ RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, MVT::bf16);
10603+ SDValue Res =
10604+ makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
10605+ return DAG.getBitcast(MVT::i32, Res);
10606+ }
10607+
1059110608SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1059210609 LLVM_DEBUG(dbgs() << "Lowering node: "; Op.dump());
1059310610 switch (Op.getOpcode()) {
@@ -10713,6 +10730,8 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1071310730 case ISD::STRICT_FSETCCS: return LowerFSETCC(Op, DAG);
1071410731 case ISD::SPONENTRY:
1071510732 return LowerSPONENTRY(Op, DAG);
10733+ case ISD::FP_TO_BF16:
10734+ return LowerFP_TO_BF16(Op, DAG);
1071610735 case ARMISD::WIN__DBZCHK: return SDValue();
1071710736 }
1071810737}
0 commit comments