Skip to content

Commit e020f46

Browse files
committed
[ARM] Fix BF16 lowering with FullFP16
This adds test coverage for bf16 instructions, making sure that lowering bf16 works with and without +fullfp16.
1 parent eaf482f commit e020f46

File tree

3 files changed

+2357
-1
lines changed

3 files changed

+2357
-1
lines changed

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
804804
setAllExpand(MVT::bf16);
805805
if (!Subtarget->hasFullFP16())
806806
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
807+
} else {
808+
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
809+
setOperationAction(ISD::FP_TO_BF16, MVT::f32, Custom);
807810
}
808811

809812
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
@@ -6301,10 +6304,13 @@ SDValue ARMTargetLowering::ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
63016304
DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), MVT::i32, Op));
63026305

63036306
if ((DstVT == MVT::i16 || DstVT == MVT::i32) &&
6304-
(SrcVT == MVT::f16 || SrcVT == MVT::bf16))
6307+
(SrcVT == MVT::f16 || SrcVT == MVT::bf16)) {
6308+
if (Subtarget->hasFullFP16() && !Subtarget->hasBF16())
6309+
Op = DAG.getBitcast(MVT::f16, Op);
63056310
return DAG.getNode(
63066311
ISD::TRUNCATE, SDLoc(N), DstVT,
63076312
MoveFromHPR(SDLoc(N), DAG, MVT::i32, SrcVT.getSimpleVT(), Op));
6313+
}
63086314

63096315
if (!(SrcVT == MVT::i64 || DstVT == MVT::i64))
63106316
return SDValue();
@@ -10588,6 +10594,17 @@ SDValue ARMTargetLowering::LowerSPONENTRY(SDValue Op, SelectionDAG &DAG) const {
1058810594
return DAG.getFrameIndex(FI, VT);
1058910595
}
1059010596

10597+
/// Custom lowering for ISD::FP_TO_BF16 (dispatched from LowerOperation).
///
/// Looks up the FPROUND runtime libcall that converts operand 0's simple
/// value type to bf16, emits that call through makeLibCall with an f32
/// return type, and returns the call's result value bitcast to i32.
/// NOTE(review): presumably the bf16 bits come back in the f32 return value
/// and the i32 bitcast exposes them as an integer -- confirm against the
/// libcall's ABI.
///
/// \param Op  The FP_TO_BF16 node; operand 0 is the value being converted.
/// \param DAG The SelectionDAG used to create the libcall and bitcast nodes.
/// \returns The i32 bitcast of the libcall result.
SDValue ARMTargetLowering::LowerFP_TO_BF16(SDValue Op,
10598+
SelectionDAG &DAG) const {
10599+
SDLoc DL(Op);
10600+
MakeLibCallOptions CallOptions;
10601+
MVT SVT = Op.getOperand(0).getSimpleValueType();
10602+
RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, MVT::bf16);
10603+
SDValue Res =
10604+
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
10605+
return DAG.getBitcast(MVT::i32, Res);
10606+
}
10607+
1059110608
SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1059210609
LLVM_DEBUG(dbgs() << "Lowering node: "; Op.dump());
1059310610
switch (Op.getOpcode()) {
@@ -10713,6 +10730,8 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1071310730
case ISD::STRICT_FSETCCS: return LowerFSETCC(Op, DAG);
1071410731
case ISD::SPONENTRY:
1071510732
return LowerSPONENTRY(Op, DAG);
10733+
case ISD::FP_TO_BF16:
10734+
return LowerFP_TO_BF16(Op, DAG);
1071610735
case ARMISD::WIN__DBZCHK: return SDValue();
1071710736
}
1071810737
}

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,7 @@ class VectorType;
886886
SDValue LowerSPONENTRY(SDValue Op, SelectionDAG &DAG) const;
887887
void LowerLOAD(SDNode *N, SmallVectorImpl<SDValue> &Results,
888888
SelectionDAG &DAG) const;
889+
SDValue LowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
889890

890891
Register getRegisterByName(const char* RegName, LLT VT,
891892
const MachineFunction &MF) const override;

0 commit comments

Comments
 (0)