Commit 4b65caf

[AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16) (#167502)
This lowers an SVE FMUL of bf16 using the BFMLAL top/bottom instructions rather than extending the operands to an f32 multiply. This does require zeroing an accumulator, but it needs fewer extends and unpacks overall.
1 parent 4604762 commit 4b65caf
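For intuition, here is a minimal scalar sketch (an editor's illustration, not code from this commit) of what a single BFMLAL lane computes: the bf16 inputs are widened to f32 (a 16-bit left shift of the bit pattern, as the old lowering's lsl did), multiplied, and added to the accumulator, so with a zero accumulator the lane result is just the f32 product, which FCVT/FCVTNT then round back to bf16. The helper names are hypothetical.

#include <cstdint>
#include <cstring>

// Widen a bf16 bit pattern to f32: bf16 is the top 16 bits of an IEEE f32.
static float bf16_to_f32(uint16_t bits) {
  uint32_t wide = uint32_t(bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof f);
  return f;
}

// One lane of BFMLALB/BFMLALT: acc + widen(a) * widen(b).
// With acc seeded to zero this reduces to the f32 product of a and b,
// which bfcvt/bfcvtnt then narrow back to bf16.
static float bfmlal_lane(float acc, uint16_t a, uint16_t b) {
  return acc + bf16_to_f32(a) * bf16_to_f32(b);
}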

File tree

4 files changed, +205 −151 lines changed


llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 66 additions & 6 deletions
@@ -1809,11 +1809,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

     if (!Subtarget->hasSVEB16B16() ||
         !Subtarget->isNonStreamingSVEorSME2Available()) {
-      for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
-                          ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
-        setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
-        setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
-        setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
+      for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+        MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
+        setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
+        setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
+
+        if (VT != MVT::nxv2bf16 && Subtarget->hasBF16())
+          setOperationAction(ISD::FMUL, VT, Custom);
+        else
+          setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
       }
     }

@@ -7670,6 +7679,57 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
                       EndOfTrmp);
 }

+SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  if (VT.getScalarType() != MVT::bf16 ||
+      (Subtarget->hasSVEB16B16() &&
+       Subtarget->isNonStreamingSVEorSME2Available()))
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+
+  assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
+  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+
+  auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
+    return [&, IID](EVT VT, auto... Ops) {
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(IID, DL, MVT::i32), Ops...);
+    };
+  };
+
+  auto ReinterpretCast = [&](SDValue Value, EVT VT) {
+    if (VT == Value.getValueType())
+      return Value;
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
+  };
+
+  // Create helpers for building intrinsic calls.
+  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+  auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
+  auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
+
+  // All intrinsics expect to operate on full bf16 vector types.
+  SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
+  SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
+
+  SDValue Zero =
+      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
+  SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+
+  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // instructions. These result in two f32 vectors, which can be converted back
+  // to bf16 with FCVT and FCVTNT.
+  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue BottomBF16 =
+      FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
+  // Note: nxv4bf16 only uses even lanes.
+  if (VT == MVT::nxv4bf16)
+    return ReinterpretCast(BottomBF16, VT);
+  SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+  return FCVTNT(VT, BottomBF16, Pg, TopF32);
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7744,7 +7804,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::FSUB:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
   case ISD::FMUL:
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
+    return LowerFMUL(Op, DAG);
   case ISD::FMA:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
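Aside (editor's note, not part of the commit): MakeGetIntrinsic above is a small factory that captures an intrinsic ID and returns a generic lambda forwarding any number of operands. A standalone sketch of that pattern, with hypothetical names standing in for the DAG machinery:

#include <cstdio>

// Fix one argument (a tag string standing in for the intrinsic ID) and return
// a variadic lambda that forwards the remaining arguments, mirroring how
// MakeGetIntrinsic builds the BFMLALB/BFMLALT/FCVT/FCVTNT helpers.
static auto MakeTaggedPrinter(const char *Tag) {
  return [Tag](auto... Args) {
    std::printf("%s", Tag);
    ((std::printf(" %g", double(Args))), ...); // fold-expression over Args
    std::printf("\n");
  };
}

int main() {
  auto BFMLALB = MakeTaggedPrinter("bfmlalb"); // like MakeGetIntrinsic(IID)
  BFMLALB(0.0, 1.5, 2.0);                      // forwards like BFMLALB(VT, Zero, LHS, RHS)
  return 0;
}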

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
@@ -614,6 +614,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;

   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/sve-bf16-arith.ll

Lines changed: 58 additions & 27 deletions
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING
 ; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16
-; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING
 ; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16

 target triple = "aarch64-unknown-linux-gnu"
@@ -530,49 +530,80 @@ define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x
 ; B16B16-NEXT: ptrue p0.d
 ; B16B16-NEXT: bfmul z0.h, p0/m, z0.h, z1.h
 ; B16B16-NEXT: ret
-  %res = fmul <vscale x 2 x bfloat> %a, %b
+  %res = fmul nsz <vscale x 2 x bfloat> %a, %b
   ret <vscale x 2 x bfloat> %res
 }

 define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv4bf16:
-; NOB16B16: // %bb.0:
-; NOB16B16-NEXT: lsl z1.s, z1.s, #16
-; NOB16B16-NEXT: lsl z0.s, z0.s, #16
-; NOB16B16-NEXT: ptrue p0.s
-; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
-; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT: ret
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16:
+; NOB16B16-NONSTREAMING: // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
+; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT: ret
 ;
 ; B16B16-LABEL: fmul_nxv4bf16:
 ; B16B16: // %bb.0:
 ; B16B16-NEXT: ptrue p0.s
 ; B16B16-NEXT: bfmul z0.h, p0/m, z0.h, z1.h
 ; B16B16-NEXT: ret
-  %res = fmul <vscale x 4 x bfloat> %a, %b
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16:
+; NOB16B16-STREAMING: // %bb.0:
+; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT: ptrue p0.s
+; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT: ret
+  %res = fmul nsz <vscale x 4 x bfloat> %a, %b
   ret <vscale x 4 x bfloat> %res
 }

 define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
-; NOB16B16-LABEL: fmul_nxv8bf16:
+; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16:
+; NOB16B16-NONSTREAMING: // %bb.0:
+; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT: movi v3.2d, #0000000000000000
+; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
+; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NONSTREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-NONSTREAMING-NEXT: ret
+;
+; B16B16-LABEL: fmul_nxv8bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
+; B16B16-NEXT: ret
+;
+; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16:
+; NOB16B16-STREAMING: // %bb.0:
+; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT: mov z3.s, #0 // =0x0
+; NOB16B16-STREAMING-NEXT: ptrue p0.s
+; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
+; NOB16B16-STREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
+; NOB16B16-STREAMING-NEXT: ret
+  %res = fmul nsz <vscale x 8 x bfloat> %a, %b
+  ret <vscale x 8 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fmul_nxv8bf16_no_nsz(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; NOB16B16-LABEL: fmul_nxv8bf16_no_nsz:
 ; NOB16B16: // %bb.0:
-; NOB16B16-NEXT: uunpkhi z2.s, z1.h
-; NOB16B16-NEXT: uunpkhi z3.s, z0.h
-; NOB16B16-NEXT: uunpklo z1.s, z1.h
-; NOB16B16-NEXT: uunpklo z0.s, z0.h
+; NOB16B16-NEXT: mov z2.s, #0x80000000
+; NOB16B16-NEXT: mov z3.s, #0x80000000
 ; NOB16B16-NEXT: ptrue p0.s
-; NOB16B16-NEXT: lsl z2.s, z2.s, #16
-; NOB16B16-NEXT: lsl z3.s, z3.s, #16
-; NOB16B16-NEXT: lsl z1.s, z1.s, #16
-; NOB16B16-NEXT: lsl z0.s, z0.s, #16
-; NOB16B16-NEXT: fmul z2.s, z3.s, z2.s
-; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
-; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s
-; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
-; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOB16B16-NEXT: bfmlalb z2.s, z0.h, z1.h
+; NOB16B16-NEXT: bfmlalt z3.s, z0.h, z1.h
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z2.s
+; NOB16B16-NEXT: bfcvtnt z0.h, p0/m, z3.s
 ; NOB16B16-NEXT: ret
 ;
-; B16B16-LABEL: fmul_nxv8bf16:
+; B16B16-LABEL: fmul_nxv8bf16_no_nsz:
 ; B16B16: // %bb.0:
 ; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
 ; B16B16-NEXT: ret