Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 66 additions & 6 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1809,11 +1809,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

if (!Subtarget->hasSVEB16B16() ||
!Subtarget->isNonStreamingSVEorSME2Available()) {
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);

if (VT != MVT::nxv2bf16 && Subtarget->hasBF16())
setOperationAction(ISD::FMUL, VT, Custom);
else
setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
}
}

Expand Down Expand Up @@ -7641,6 +7650,57 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
EndOfTrmp);
}

/// Custom lowering for ISD::FMUL.
///
/// Anything that is not a bf16 vector, and bf16 vectors on subtargets where
/// SVE-B16B16 is present and usable (non-streaming SVE or SME2), is forwarded
/// to the generic predicated lowering (FMUL_PRED).
///
/// The remaining cases (nxv4bf16 and nxv8bf16 with +bf16) are expanded as
///   BFMLALB/BFMLALT (accumulating onto the FADD neutral element) -> f32
///   FCVT/FCVTNT                                                  -> bf16
SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
  SDLoc DL(Op);
  EVT VT = Op.getValueType();

  // Non-bf16 multiplies always use the predicated node.
  if (VT.getScalarType() != MVT::bf16)
    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);

  // So do bf16 multiplies when SVE-B16B16 is available in the current mode.
  if (Subtarget->hasSVEB16B16() &&
      Subtarget->isNonStreamingSVEorSME2Available())
    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);

  assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");

  // Builds an INTRINSIC_WO_CHAIN node: the intrinsic ID is materialised as an
  // i32 constant and placed in front of the supplied operands.
  auto EmitIntrinsic = [&](Intrinsic::ID IID, EVT ResultVT, auto... Operands) {
    SDValue IDNode = DAG.getConstant(IID, DL, MVT::i32);
    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResultVT, IDNode,
                       Operands...);
  };

  // Returns V unchanged when it already has type DestVT; otherwise wraps it
  // in a REINTERPRET_CAST to DestVT.
  auto CastTo = [&](EVT DestVT, SDValue V) -> SDValue {
    if (V.getValueType() == DestVT)
      return V;
    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, DestVT, V);
  };

  // All intrinsics expect to operate on full bf16 vector types.
  SDValue LHS = CastTo(MVT::nxv8bf16, Op.getOperand(0));
  SDValue RHS = CastTo(MVT::nxv8bf16, Op.getOperand(1));

  // The accumulator is the FADD neutral element chosen from the node's flags,
  // so the multiply-accumulate yields just the product.
  SDValue Acc =
      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
  // Governing predicate for the conversions (constant 1 of type nxv4i1).
  SDValue Pred = DAG.getConstant(1, DL, MVT::nxv4i1);

  // Bottom half: BFMLALB gives an f32 vector, which FCVT converts back to
  // bf16 on top of a poison passthru.
  SDValue LoF32 =
      EmitIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32, Acc, LHS, RHS);
  SDValue Passthru = DAG.getPOISON(MVT::nxv8bf16);
  SDValue LoBF16 = EmitIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2,
                                 MVT::nxv8bf16, Passthru, Pred, LoF32);

  // Note: nxv4bf16 only uses even lanes, so the bottom result is all that is
  // needed for that type.
  if (VT == MVT::nxv4bf16)
    return CastTo(VT, LoBF16);

  // Top half (VT == nxv8bf16): BFMLALT gives the second f32 vector and FCVTNT
  // converts it back to bf16, taking the bottom result as its first operand.
  SDValue HiF32 =
      EmitIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32, Acc, LHS, RHS);
  return EmitIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT, LoBF16,
                       Pred, HiF32);
}

SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
LLVM_DEBUG(dbgs() << "Custom lowering: ");
Expand Down Expand Up @@ -7715,7 +7775,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::FSUB:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
case ISD::FMUL:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
return LowerFMUL(Op, DAG);
case ISD::FMA:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
case ISD::FDIV:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
Expand Down
85 changes: 58 additions & 27 deletions llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING
; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16
; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING
; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16

target triple = "aarch64-unknown-linux-gnu"
Expand Down Expand Up @@ -527,48 +527,79 @@ define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x
; B16B16: // %bb.0:
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
; B16B16-NEXT: ret
%res = fmul <vscale x 2 x bfloat> %a, %b
%res = fmul nsz <vscale x 2 x bfloat> %a, %b
ret <vscale x 2 x bfloat> %res
}

define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
; NOB16B16-LABEL: fmul_nxv4bf16:
; NOB16B16: // %bb.0:
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
; NOB16B16-NEXT: ptrue p0.s
; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
; NOB16B16-NEXT: ret
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16:
; NOB16B16-NONSTREAMING: // %bb.0:
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
; NOB16B16-NONSTREAMING-NEXT: ret
;
; B16B16-LABEL: fmul_nxv4bf16:
; B16B16: // %bb.0:
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
; B16B16-NEXT: ret
%res = fmul <vscale x 4 x bfloat> %a, %b
;
; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16:
; NOB16B16-STREAMING: // %bb.0:
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
; NOB16B16-STREAMING-NEXT: ptrue p0.s
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
; NOB16B16-STREAMING-NEXT: ret
%res = fmul nsz <vscale x 4 x bfloat> %a, %b
ret <vscale x 4 x bfloat> %res
}

define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
; NOB16B16-LABEL: fmul_nxv8bf16:
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16:
; NOB16B16-NONSTREAMING: // %bb.0:
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
; NOB16B16-NONSTREAMING-NEXT: movi v3.2d, #0000000000000000
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
; NOB16B16-NONSTREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
; NOB16B16-NONSTREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
; NOB16B16-NONSTREAMING-NEXT: ret
;
; B16B16-LABEL: fmul_nxv8bf16:
; B16B16: // %bb.0:
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
; B16B16-NEXT: ret
;
; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16:
; NOB16B16-STREAMING: // %bb.0:
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
; NOB16B16-STREAMING-NEXT: mov z3.s, #0 // =0x0
; NOB16B16-STREAMING-NEXT: ptrue p0.s
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
; NOB16B16-STREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
; NOB16B16-STREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
; NOB16B16-STREAMING-NEXT: ret
%res = fmul nsz <vscale x 8 x bfloat> %a, %b
ret <vscale x 8 x bfloat> %res
}

define <vscale x 8 x bfloat> @fmul_nxv8bf16_no_nsz(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
; NOB16B16-LABEL: fmul_nxv8bf16_no_nsz:
; NOB16B16: // %bb.0:
; NOB16B16-NEXT: uunpkhi z2.s, z1.h
; NOB16B16-NEXT: uunpkhi z3.s, z0.h
; NOB16B16-NEXT: uunpklo z1.s, z1.h
; NOB16B16-NEXT: uunpklo z0.s, z0.h
; NOB16B16-NEXT: mov z2.s, #0x80000000
; NOB16B16-NEXT: mov z3.s, #0x80000000
; NOB16B16-NEXT: ptrue p0.s
; NOB16B16-NEXT: lsl z2.s, z2.s, #16
; NOB16B16-NEXT: lsl z3.s, z3.s, #16
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
; NOB16B16-NEXT: fmul z2.s, z3.s, z2.s
; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h
; NOB16B16-NEXT: bfmlalb z2.s, z0.h, z1.h
; NOB16B16-NEXT: bfmlalt z3.s, z0.h, z1.h
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z2.s
; NOB16B16-NEXT: bfcvtnt z0.h, p0/m, z3.s
; NOB16B16-NEXT: ret
;
; B16B16-LABEL: fmul_nxv8bf16:
; B16B16-LABEL: fmul_nxv8bf16_no_nsz:
; B16B16: // %bb.0:
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
; B16B16-NEXT: ret
Expand Down
Loading