Skip to content

Commit a0925e8

Browse files
committed
[AArch64][SVE] Add custom lowering for bfloat FMUL (with +bf16)
This lowers an SVE FMUL of bf16 using the BFMLAL top/bottom instructions rather than extending to an f32 mul. This does require zeroing the accumulator, but requires fewer extends/unpacking.
1 parent 1122581 commit a0925e8

File tree

3 files changed

+98
-36
lines changed

3 files changed

+98
-36
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1803,6 +1803,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18031803
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
18041804
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
18051805
}
1806+
1807+
if (Subtarget->hasBF16() &&
1808+
(Subtarget->hasSVE() || Subtarget->hasSME())) {
1809+
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
1810+
setOperationAction(ISD::FMUL, VT, Custom);
1811+
}
18061812
}
18071813

18081814
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7538,6 +7544,43 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
75387544
EndOfTrmp);
75397545
}
75407546

7547+
SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
7548+
EVT VT = Op.getValueType();
7549+
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
7550+
if (VT.getScalarType() != MVT::bf16 ||
7551+
!(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
7552+
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
7553+
7554+
SDLoc DL(Op);
7555+
SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
7556+
SDValue LHS = Op.getOperand(0);
7557+
SDValue RHS = Op.getOperand(1);
7558+
7559+
auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
7560+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7561+
DAG.getConstant(IID, DL, MVT::i32), Ops...);
7562+
};
7563+
7564+
SDValue Pg =
7565+
getPTrue(DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
7566+
AArch64SVEPredPattern::all);
7567+
// Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
7568+
// instructions. These result in two f32 vectors, which can be converted back
7569+
// to bf16 with FCVT and FCVNT.
7570+
SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
7571+
Zero, LHS, RHS);
7572+
SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
7573+
DAG.getPOISON(VT), Pg, BottomF32);
7574+
if (VT == MVT::nxv8bf16) {
7575+
// Note: nxv2bf16 and nxv4bf16 only use even lanes.
7576+
SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
7577+
Zero, LHS, RHS);
7578+
return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
7579+
BottomBF16, Pg, TopF32);
7580+
}
7581+
return BottomBF16;
7582+
}
7583+
75417584
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
75427585
SelectionDAG &DAG) const {
75437586
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7612,7 +7655,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
76127655
case ISD::FSUB:
76137656
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
76147657
case ISD::FMUL:
7615-
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
7658+
return LowerFMUL(Op, DAG);
76167659
case ISD::FMA:
76177660
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
76187661
case ISD::FDIV:

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ class AArch64TargetLowering : public TargetLowering {
609609
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
610610
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
611611
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
612+
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
612613

613614
SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
614615
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/sve-bf16-arith.ll

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
2+
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-NONSTREAMING
33
; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16
4-
; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
4+
; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16,NOB16B16-STREAMING
55
; RUN: llc -mattr=+sme2,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16
66

77
target triple = "aarch64-unknown-linux-gnu"
@@ -514,64 +514,82 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
514514
;
515515

516516
define <vscale x 2 x bfloat> @fmul_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
517-
; NOB16B16-LABEL: fmul_nxv2bf16:
518-
; NOB16B16: // %bb.0:
519-
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
520-
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
521-
; NOB16B16-NEXT: ptrue p0.d
522-
; NOB16B16-NEXT: fmul z0.s, p0/m, z0.s, z1.s
523-
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
524-
; NOB16B16-NEXT: ret
517+
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv2bf16:
518+
; NOB16B16-NONSTREAMING: // %bb.0:
519+
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
520+
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.d
521+
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
522+
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
523+
; NOB16B16-NONSTREAMING-NEXT: ret
525524
;
526525
; B16B16-LABEL: fmul_nxv2bf16:
527526
; B16B16: // %bb.0:
528527
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
529528
; B16B16-NEXT: ret
529+
;
530+
; NOB16B16-STREAMING-LABEL: fmul_nxv2bf16:
531+
; NOB16B16-STREAMING: // %bb.0:
532+
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
533+
; NOB16B16-STREAMING-NEXT: ptrue p0.d
534+
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
535+
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
536+
; NOB16B16-STREAMING-NEXT: ret
530537
%res = fmul <vscale x 2 x bfloat> %a, %b
531538
ret <vscale x 2 x bfloat> %res
532539
}
533540

534541
define <vscale x 4 x bfloat> @fmul_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
535-
; NOB16B16-LABEL: fmul_nxv4bf16:
536-
; NOB16B16: // %bb.0:
537-
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
538-
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
539-
; NOB16B16-NEXT: ptrue p0.s
540-
; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
541-
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
542-
; NOB16B16-NEXT: ret
542+
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv4bf16:
543+
; NOB16B16-NONSTREAMING: // %bb.0:
544+
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
545+
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
546+
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
547+
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
548+
; NOB16B16-NONSTREAMING-NEXT: ret
543549
;
544550
; B16B16-LABEL: fmul_nxv4bf16:
545551
; B16B16: // %bb.0:
546552
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
547553
; B16B16-NEXT: ret
554+
;
555+
; NOB16B16-STREAMING-LABEL: fmul_nxv4bf16:
556+
; NOB16B16-STREAMING: // %bb.0:
557+
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
558+
; NOB16B16-STREAMING-NEXT: ptrue p0.s
559+
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
560+
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
561+
; NOB16B16-STREAMING-NEXT: ret
548562
%res = fmul <vscale x 4 x bfloat> %a, %b
549563
ret <vscale x 4 x bfloat> %res
550564
}
551565

552566
define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
553-
; NOB16B16-LABEL: fmul_nxv8bf16:
554-
; NOB16B16: // %bb.0:
555-
; NOB16B16-NEXT: uunpkhi z2.s, z1.h
556-
; NOB16B16-NEXT: uunpkhi z3.s, z0.h
557-
; NOB16B16-NEXT: uunpklo z1.s, z1.h
558-
; NOB16B16-NEXT: uunpklo z0.s, z0.h
559-
; NOB16B16-NEXT: ptrue p0.s
560-
; NOB16B16-NEXT: lsl z2.s, z2.s, #16
561-
; NOB16B16-NEXT: lsl z3.s, z3.s, #16
562-
; NOB16B16-NEXT: lsl z1.s, z1.s, #16
563-
; NOB16B16-NEXT: lsl z0.s, z0.s, #16
564-
; NOB16B16-NEXT: fmul z2.s, z3.s, z2.s
565-
; NOB16B16-NEXT: fmul z0.s, z0.s, z1.s
566-
; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s
567-
; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
568-
; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h
569-
; NOB16B16-NEXT: ret
567+
; NOB16B16-NONSTREAMING-LABEL: fmul_nxv8bf16:
568+
; NOB16B16-NONSTREAMING: // %bb.0:
569+
; NOB16B16-NONSTREAMING-NEXT: movi v2.2d, #0000000000000000
570+
; NOB16B16-NONSTREAMING-NEXT: movi v3.2d, #0000000000000000
571+
; NOB16B16-NONSTREAMING-NEXT: ptrue p0.s
572+
; NOB16B16-NONSTREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
573+
; NOB16B16-NONSTREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
574+
; NOB16B16-NONSTREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
575+
; NOB16B16-NONSTREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
576+
; NOB16B16-NONSTREAMING-NEXT: ret
570577
;
571578
; B16B16-LABEL: fmul_nxv8bf16:
572579
; B16B16: // %bb.0:
573580
; B16B16-NEXT: bfmul z0.h, z0.h, z1.h
574581
; B16B16-NEXT: ret
582+
;
583+
; NOB16B16-STREAMING-LABEL: fmul_nxv8bf16:
584+
; NOB16B16-STREAMING: // %bb.0:
585+
; NOB16B16-STREAMING-NEXT: mov z2.s, #0 // =0x0
586+
; NOB16B16-STREAMING-NEXT: mov z3.s, #0 // =0x0
587+
; NOB16B16-STREAMING-NEXT: ptrue p0.s
588+
; NOB16B16-STREAMING-NEXT: bfmlalb z2.s, z0.h, z1.h
589+
; NOB16B16-STREAMING-NEXT: bfmlalt z3.s, z0.h, z1.h
590+
; NOB16B16-STREAMING-NEXT: bfcvt z0.h, p0/m, z2.s
591+
; NOB16B16-STREAMING-NEXT: bfcvtnt z0.h, p0/m, z3.s
592+
; NOB16B16-STREAMING-NEXT: ret
575593
%res = fmul <vscale x 8 x bfloat> %a, %b
576594
ret <vscale x 8 x bfloat> %res
577595
}

0 commit comments

Comments
 (0)