Skip to content

Commit ef768e8

Browse files
committed
Fixups
1 parent eaf1129 commit ef768e8

File tree

1 file changed

+40
-31
lines changed

1 file changed

+40
-31
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,17 +1795,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
17951795

17961796
if (!Subtarget->hasSVEB16B16() ||
17971797
!Subtarget->isNonStreamingSVEorSME2Available()) {
1798-
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
1799-
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
1800-
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
1801-
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
1802-
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
1803-
}
1804-
1805-
if (Subtarget->hasBF16() &&
1806-
(Subtarget->hasSVE() || Subtarget->hasSME())) {
1807-
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
1798+
for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
1799+
MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
1800+
setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
1801+
setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
1802+
setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
1803+
setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
1804+
setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
1805+
setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
1806+
setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
1807+
1808+
if (Subtarget->hasBF16())
18081809
setOperationAction(ISD::FMUL, VT, Custom);
1810+
else
1811+
setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
18091812
}
18101813
}
18111814

@@ -7536,40 +7539,46 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
75367539
}
75377540

75387541
SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
7542+
SDLoc DL(Op);
75397543
EVT VT = Op.getValueType();
75407544
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
75417545
if (VT.getScalarType() != MVT::bf16 ||
7542-
!(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
7546+
(Subtarget.hasSVEB16B16() &&
7547+
Subtarget.isNonStreamingSVEorSME2Available()))
75437548
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
75447549

7545-
SDLoc DL(Op);
7550+
assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
7551+
7552+
auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
7553+
return [&, IID](EVT VT, auto... Ops) {
7554+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7555+
DAG.getConstant(IID, DL, MVT::i32), Ops...);
7556+
};
7557+
};
7558+
7559+
// Create helpers for building intrinsic calls.
7560+
auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
7561+
auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
7562+
auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
7563+
auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
7564+
75467565
SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
75477566
SDValue LHS = Op.getOperand(0);
75487567
SDValue RHS = Op.getOperand(1);
75497568

7550-
auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
7551-
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7552-
DAG.getConstant(IID, DL, MVT::i32), Ops...);
7553-
};
7554-
75557569
SDValue Pg =
7556-
getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
7557-
AArch64SVEPredPattern::all);
7570+
DAG.getConstant(1, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
7571+
75587572
// Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
75597573
// instructions. These result in two f32 vectors, which can be converted back
75607574
// to bf16 with FCVT and FCVNT.
7561-
SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
7562-
Zero, LHS, RHS);
7563-
SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
7564-
DAG.getPOISON(VT), Pg, BottomF32);
7565-
if (VT == MVT::nxv8bf16) {
7566-
// Note: nxv2bf16 and nxv4bf16 only use even lanes.
7567-
SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
7568-
Zero, LHS, RHS);
7569-
return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
7570-
BottomBF16, Pg, TopF32);
7571-
}
7572-
return BottomBF16;
7575+
SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
7576+
SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
7577+
// Note: nxv2bf16 and nxv4bf16 only use even lanes.
7578+
if (VT != MVT::nxv8bf16)
7579+
return BottomBF16;
7580+
SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
7581+
return FCVNT(VT, BottomBF16, Pg, TopF32);
75737582
}
75747583

75757584
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,

0 commit comments

Comments
 (0)