Skip to content

Commit eaf1129

Browse files
committed
[AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16)
This lowers an SVE FMUL of bf16 using the BFMLAL top/bottom instructions rather than extending to an f32 mul. This does require zeroing the accumulator, but requires fewer extends/unpacking.
1 parent 0f4c8dd commit eaf1129

File tree

3 files changed

+98
-36
lines changed

3 files changed

+98
-36
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1801,6 +1801,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18011801
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
18021802
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
18031803
}
1804+
1805+
if (Subtarget->hasBF16() &&
1806+
(Subtarget->hasSVE() || Subtarget->hasSME())) {
1807+
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
1808+
setOperationAction(ISD::FMUL, VT, Custom);
1809+
}
18041810
}
18051811

18061812
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7529,6 +7535,43 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
75297535
EndOfTrmp);
75307536
}
75317537

7538+
SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
7539+
EVT VT = Op.getValueType();
7540+
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
7541+
if (VT.getScalarType() != MVT::bf16 ||
7542+
!(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
7543+
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
7544+
7545+
SDLoc DL(Op);
7546+
SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
7547+
SDValue LHS = Op.getOperand(0);
7548+
SDValue RHS = Op.getOperand(1);
7549+
7550+
auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
7551+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7552+
DAG.getConstant(IID, DL, MVT::i32), Ops...);
7553+
};
7554+
7555+
SDValue Pg =
7556+
getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
7557+
AArch64SVEPredPattern::all);
7558+
// Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
7559+
// instructions. These result in two f32 vectors, which can be converted back
7560+
// to bf16 with FCVT and FCVNT.
7561+
SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
7562+
Zero, LHS, RHS);
7563+
SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
7564+
DAG.getPOISON(VT), Pg, BottomF32);
7565+
if (VT == MVT::nxv8bf16) {
7566+
// Note: nxv2bf16 and nxv4bf16 only use even lanes.
7567+
SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
7568+
Zero, LHS, RHS);
7569+
return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
7570+
BottomBF16, Pg, TopF32);
7571+
}
7572+
return BottomBF16;
7573+
}
7574+
75327575
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
75337576
SelectionDAG &DAG) const {
75347577
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7603,7 +7646,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
76037646
case ISD::FSUB:
76047647
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
76057648
case ISD::FMUL:
7606-
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
7649+
return LowerFMUL(Op, DAG);
76077650
case ISD::FMA:
76087651
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
76097652
case ISD::FDIV:

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ class AArch64TargetLowering : public TargetLowering {
609609
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
610610
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
611611
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
612+
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
612613

613614
SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
614615
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/sve-bf16-arith.ll

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
2+
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING
33
; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16
4-
; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
4+
; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING
55
; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16
66

77
target triple = "aarch64-unknown-linux-gnu"
@@ -520,64 +520,82 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
520520
;
521521

522522
define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
523-
; NOB16B16-LABEL: fmul_nxv2bf16:
524-
; NOB16B16: // %bb.0:
525-
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
526-
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
527-
; NOB16B16-NEXT: ptrue p0.d
528-
; NOB16B16-NEXT: fmul z0.s, p0/m, z0.s, z1.s
529-
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
530-
; NOB16B16-NEXT: ret
523+
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16:
524+
; NOB16B16-NONSTREAMING: // %bb.0:
525+
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
526+
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.d
527+
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
528+
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
529+
; NOB16B16-NONSTREAMING-NEXT: ret
531530
;
532531
; B16B16-LABEL: fmul_nxv2bf16:
533532
; B16B16: // %bb.0:
534533
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
535534
; B16B16-NEXT: ret
535+
;
536+
; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16:
537+
; NOB16B16-STREAMING: // %bb.0:
538+
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
539+
; NOB16B16-STREAMING-NEXT: ptrue p0.d
540+
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
541+
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
542+
; NOB16B16-STREAMING-NEXT: ret
536543
%res = fmul <vscale x 2 x bfloat> %a, %b
537544
ret <vscale x 2 x bfloat> %res
538545
}
539546

540547
define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
541-
; NOB16B16-LABEL: fmul_nxv4bf16:
542-
; NOB16B16: // %bb.0:
543-
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
544-
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
545-
; NOB16B16-NEXT: ptrue p0.s
546-
; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
547-
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
548-
; NOB16B16-NEXT: ret
548+
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16:
549+
; NOB16B16-NONSTREAMING: // %bb.0:
550+
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
551+
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
552+
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
553+
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
554+
; NOB16B16-NONSTREAMING-NEXT: ret
549555
;
550556
; B16B16-LABEL: fmul_nxv4bf16:
551557
; B16B16: // %bb.0:
552558
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
553559
; B16B16-NEXT: ret
560+
;
561+
; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16:
562+
; NOB16B16-STREAMING: // %bb.0:
563+
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
564+
; NOB16B16-STREAMING-NEXT: ptrue p0.s
565+
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
566+
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
567+
; NOB16B16-STREAMING-NEXT: ret
554568
%res = fmul <vscale x 4 x bfloat> %a, %b
555569
ret <vscale x 4 x bfloat> %res
556570
}
557571

558572
define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
559-
; NOB16B16-LABEL: fmul_nxv8bf16:
560-
; NOB16B16: // %bb.0:
561-
; NOB16B16-NEXT: uunpkhi z2.s, z1.h
562-
; NOB16B16-NEXT: uunpkhi z3.s, z0.h
563-
; NOB16B16-NEXT: uunpklo z1.s, z1.h
564-
; NOB16B16-NEXT: uunpklo z0.s, z0.h
565-
; NOB16B16-NEXT: ptrue p0.s
566-
; NOB16B16-NEXT: lsl z2.s, z2.s, #16
567-
; NOB16B16-NEXT: lsl z3.s, z3.s, #16
568-
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
569-
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
570-
; NOB16B16-NEXT: fmul z2.s, z3.s, z2.s
571-
; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
572-
; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s
573-
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
574-
; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h
575-
; NOB16B16-NEXT: ret
573+
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16:
574+
; NOB16B16-NONSTREAMING: // %bb.0:
575+
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
576+
; NOB16B16-NONSTREAMING-NEXT: movi v3.2d, #0000000000000000
577+
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
578+
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
579+
; NOB16B16-NONSTREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
580+
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
581+
; NOB16B16-NONSTREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
582+
; NOB16B16-NONSTREAMING-NEXT: ret
576583
;
577584
; B16B16-LABEL: fmul_nxv8bf16:
578585
; B16B16: // %bb.0:
579586
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
580587
; B16B16-NEXT: ret
588+
;
589+
; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16:
590+
; NOB16B16-STREAMING: // %bb.0:
591+
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
592+
; NOB16B16-STREAMING-NEXT: mov z3.s, #0 // =0x0
593+
; NOB16B16-STREAMING-NEXT: ptrue p0.s
594+
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
595+
; NOB16B16-STREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
596+
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
597+
; NOB16B16-STREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
598+
; NOB16B16-STREAMING-NEXT: ret
581599
%res = fmul <vscale x 8 x bfloat> %a, %b
582600
ret <vscale x 8 x bfloat> %res
583601
}

0 commit comments

Comments
 (0)