Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 36 additions & 46 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,8 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
setTargetDAGCombine(ISD::CTPOP);

setTargetDAGCombine(ISD::FMA);

// In case of strict alignment, avoid an excessive number of byte wide stores.
MaxStoresPerMemsetOptSize = 8;
MaxStoresPerMemset =
Expand Down Expand Up @@ -1526,6 +1524,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);

for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) {
setOperationAction(ISD::FMA, VT, Custom);
}
}

if (Subtarget->isSVEorStreamingSVEAvailable()) {
Expand Down Expand Up @@ -7732,6 +7734,37 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
return FCVTNT(VT, BottomBF16, Pg, TopF32);
}

// Custom-lower a fixed-length vector ISD::FMA to the SVE predicated
// AArch64ISD::FMA_PRED form, peeling any FNEG operands into
// FNEG_MERGE_PASSTHRU nodes so that instruction selection can fold the
// negations (fnmls/fnmla). Scalable and non-negated cases take the plain
// predicated-op path.
SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
  SDValue OpA = Op->getOperand(0);
  SDValue OpB = Op->getOperand(1);
  SDValue OpC = Op->getOperand(2);
  EVT VT = Op.getValueType();
  SDLoc DL(Op);

  // Bail early if we're definitely not looking to merge FNEGs into the FMA.
  // All patterns of interest below require a negated addend (operand 2).
  if (!VT.isFixedLengthVector() || OpC.getOpcode() != ISD::FNEG)
    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);

  // Convert FMA/FNEG nodes to SVE to enable the following patterns:
  //   fma(a, b, neg(c))      -> fnmls(a, b, c)
  //   fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
  //   fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
  SDValue Pg = getPredicateForVector(DAG, DL, VT);
  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);

  // Rewrite each operand into its scalable container type. Note: the loop
  // variable is deliberately NOT named `Op` so it cannot shadow the function
  // parameter above.
  for (SDValue *Operand : {&OpA, &OpB, &OpC}) {
    // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
    // it appends (its operand 0 is the scalable-typed result we want), so the
    // operand stays in ContainerVT for the FMA_PRED node below.
    *Operand =
        Operand->getOpcode() == ISD::FNEG
            ? LowerToPredicatedOp(*Operand, DAG,
                                  AArch64ISD::FNEG_MERGE_PASSTHRU)
                  ->getOperand(0)
            : convertToScalableVector(DAG, ContainerVT, *Operand);
  }
  SDValue ScalableRes =
      DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
  return convertFromScalableVector(DAG, VT, ScalableRes);
}

SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
LLVM_DEBUG(dbgs() << "Custom lowering: ");
Expand Down Expand Up @@ -7808,7 +7841,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::FMUL:
return LowerFMUL(Op, DAG);
case ISD::FMA:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
return LowerFMA(Op, DAG);
case ISD::FDIV:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
case ISD::FNEG:
Expand Down Expand Up @@ -20694,47 +20727,6 @@ static SDValue performFADDCombine(SDNode *N,
return SDValue();
}

// DAG combine that converts a legal fixed-length vector FMA whose addend
// (operand 2) is an FNEG into the SVE predicated AArch64ISD::FMA_PRED form,
// so instruction selection can fold the negations into fnmls/fnmla.
// Returns an empty SDValue when the combine does not apply.
static SDValue performFMACombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const AArch64Subtarget *Subtarget) {
  SelectionDAG &DAG = DCI.DAG;
  SDValue OpA = N->getOperand(0);
  SDValue OpB = N->getOperand(1);
  SDValue OpC = N->getOperand(2);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // Convert FMA/FNEG nodes to SVE to enable the following patterns:
  // fma(a, b, neg(c)) -> fnmls(a, b, c)
  // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
  // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
  // Gate: only legal fixed-length vectors on an SVE/streaming-SVE target, and
  // only when the addend is negated (the common anchor of all three patterns).
  if (!VT.isFixedLengthVector() ||
      !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
      !Subtarget->isSVEorStreamingSVEAvailable() ||
      OpC.getOpcode() != ISD::FNEG) {
    return SDValue();
  }

  SDValue Pg = getPredicateForVector(DAG, DL, VT);
  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
  // Peel the FNEG off the addend: re-emit it as a predicated SVE fneg on the
  // scalable container type (passthru lanes are undef).
  OpC =
      DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
                  convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0)),
                  DAG.getUNDEF(ContainerVT));

  // Likewise peel an optional FNEG off the first multiplicand.
  OpA = OpA.getOpcode() == ISD::FNEG
            ? DAG.getNode(
                  AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
                  convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0)),
                  DAG.getUNDEF(ContainerVT))
            : convertToScalableVector(DAG, ContainerVT, OpA);

  // NOTE(review): unlike OpA, a negated OpB is converted as-is here (its FNEG
  // is not rewritten to FNEG_MERGE_PASSTHRU), even though the comment above
  // lists fma(a, neg(b), neg(c)); presumably later patterns still fold that
  // case -- confirm against the fnmla selection patterns.
  OpB = convertToScalableVector(DAG, ContainerVT, OpB);
  SDValue ScalableRes =
      DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
  return convertFromScalableVector(DAG, VT, ScalableRes);
}

static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
switch (Opcode) {
case ISD::STRICT_FADD:
Expand Down Expand Up @@ -28266,8 +28258,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performANDCombine(N, DCI);
case ISD::FADD:
return performFADDCombine(N, DCI);
case ISD::FMA:
return performFMACombine(N, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performIntrinsicCombine(N, DCI, Subtarget);
case ISD::ANY_EXTEND:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFMA(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@ define <4 x double> @simple_symmetric_muladd2(<4 x double> %a, <4 x double> %b)
; CHECK-LABEL: simple_symmetric_muladd2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov x8, #-7378697629483820647 // =0x9999999999999999
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q3 killed $q3 def $z3
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: movk x8, #39322
; CHECK-NEXT: movk x8, #16393, lsl #48
; CHECK-NEXT: dup v4.2d, x8
; CHECK-NEXT: fmla v2.2d, v4.2d, v0.2d
; CHECK-NEXT: fmla v3.2d, v4.2d, v1.2d
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: mov v1.16b, v3.16b
; CHECK-NEXT: fmad z0.d, p0/m, z4.d, z2.d
; CHECK-NEXT: fmad z1.d, p0/m, z4.d, z3.d
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: // kill: def $q1 killed $q1 killed $z1
; CHECK-NEXT: ret
entry:
%ext00 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
Expand Down Expand Up @@ -43,10 +48,11 @@ define <8 x double> @simple_symmetric_muladd4(<8 x double> %a, <8 x double> %b)
; CHECK-NEXT: zip1 v17.2d, v5.2d, v7.2d
; CHECK-NEXT: zip2 v5.2d, v5.2d, v7.2d
; CHECK-NEXT: dup v6.2d, x8
; CHECK-NEXT: fmla v3.2d, v6.2d, v16.2d
; CHECK-NEXT: fmla v4.2d, v6.2d, v0.2d
; CHECK-NEXT: fmla v17.2d, v6.2d, v2.2d
; CHECK-NEXT: fmla v5.2d, v6.2d, v1.2d
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: fmla z3.d, p0/m, z16.d, z6.d
; CHECK-NEXT: fmla z4.d, p0/m, z0.d, z6.d
; CHECK-NEXT: fmla z17.d, p0/m, z2.d, z6.d
; CHECK-NEXT: fmla z5.d, p0/m, z1.d, z6.d
; CHECK-NEXT: zip1 v0.2d, v3.2d, v4.2d
; CHECK-NEXT: zip2 v2.2d, v3.2d, v4.2d
; CHECK-NEXT: zip1 v1.2d, v17.2d, v5.2d
Expand Down
24 changes: 18 additions & 6 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,12 @@ define <4 x half> @fma_v4f16(<4 x half> %op1, <4 x half> %op2, <4 x half> %op3)
define <8 x half> @fma_v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: fmla v2.8h, v1.8h, v0.8h
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ptrue p0.h, vl8
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = call <8 x half> @llvm.fma.v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3)
ret <8 x half> %res
Expand Down Expand Up @@ -730,8 +734,12 @@ define <2 x float> @fma_v2f32(<2 x float> %op1, <2 x float> %op2, <2 x float> %o
define <4 x float> @fma_v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: fmla v2.4s, v1.4s, v0.4s
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ptrue p0.s, vl4
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = call <4 x float> @llvm.fma.v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3)
ret <4 x float> %res
Expand Down Expand Up @@ -839,8 +847,12 @@ define <1 x double> @fma_v1f64(<1 x double> %op1, <1 x double> %op2, <1 x double
define <2 x double> @fma_v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: fmla v2.2d, v1.2d, v0.2d
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = call <2 x double> @llvm.fma.v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3)
ret <2 x double> %res
Expand Down
84 changes: 58 additions & 26 deletions llvm/test/CodeGen/AArch64/sve-fmsub.ll
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,57 @@ entry:
ret <8 x half> %0
}

;; fmuladd with both the second multiplicand and the addend negated,
;; <2 x double>: expected to select the single SVE fnmad instruction
;; (CHECK lines are autogenerated -- do not edit by hand).
define <2 x double> @fnmsub_negated_b_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: fnmsub_negated_b_v2f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%neg = fneg <2 x double> %b
%neg1 = fneg <2 x double> %c
%0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %neg, <2 x double> %neg1)
ret <2 x double> %0
}

;; Same pattern as fnmsub_negated_b_v2f64 but for <4 x float>:
;; fmuladd(a, fneg(b), fneg(c)) should select SVE fnmad.
define <4 x float> @fnmsub_negated_b_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
; CHECK-LABEL: fnmsub_negated_b_v4f32:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ptrue p0.s, vl4
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%neg = fneg <4 x float> %b
%neg1 = fneg <4 x float> %c
%0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %neg, <4 x float> %neg1)
ret <4 x float> %0
}

;; Same pattern as fnmsub_negated_b_v2f64 but for <8 x half>:
;; fmuladd(a, fneg(b), fneg(c)) should select SVE fnmad.
define <8 x half> @fnmsub_negated_b_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
; CHECK-LABEL: fnmsub_negated_b_v8f16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ptrue p0.h, vl8
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%neg = fneg <8 x half> %b
%neg1 = fneg <8 x half> %c
%0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %neg, <8 x half> %neg1)
ret <8 x half> %0
}

define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: fnmsub_flipped_v2f64:
; CHECK: // %bb.0: // %entry
Expand Down Expand Up @@ -333,32 +384,10 @@ entry:
}

define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: ptrue p0.d, vl1
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-SVE-NEXT: ret
;
; CHECK-SME-LABEL: fmsub_illegal_v1f64:
; CHECK-SME: // %bb.0: // %entry
; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-SME-NEXT: addvl sp, sp, #-1
; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-SME-NEXT: .cfi_offset w29, -16
; CHECK-SME-NEXT: ptrue p0.d, vl1
; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
; CHECK-SME-NEXT: str z0, [sp]
; CHECK-SME-NEXT: ldr d0, [sp]
; CHECK-SME-NEXT: addvl sp, sp, #1
; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-SME-NEXT: ret
; CHECK-LABEL: fmsub_illegal_v1f64:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: fnmsub d0, d0, d1, d2
; CHECK-NEXT: ret
entry:
%neg = fneg <1 x double> %c
%0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
Expand Down Expand Up @@ -427,3 +456,6 @@ entry:
%0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
ret <7 x half> %0
}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CHECK-SME: {{.*}}
; CHECK-SVE: {{.*}}