Skip to content

Commit ccfe7e0

Browse files
authored
[AArch64] Combine vector FNEG+FMA into FNML[A|S] (llvm#167900)
This allows for FNEG + FMA sequences to be combined into a single operation, with `FNML[A|S]`, `FNMAD`, or `FNMSB` selected depending on the operand order, similarly to how `performSVEMulAddSubCombine` enables generating `ML[A|S]` instructions by converting the ADD/SUB intrinsics to scalable vectors.
1 parent 8ce5d85 commit ccfe7e0

File tree

4 files changed

+370
-1
lines changed

4 files changed

+370
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1523,6 +1523,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15231523

15241524
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
15251525
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
1526+
1527+
for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64})
1528+
setOperationAction(ISD::FMA, VT, Custom);
15261529
}
15271530

15281531
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7781,6 +7784,46 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77817784
return FCVTNT(VT, BottomBF16, Pg, TopF32);
77827785
}
77837786

7787+
SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
7788+
SDValue OpA = Op->getOperand(0);
7789+
SDValue OpB = Op->getOperand(1);
7790+
SDValue OpC = Op->getOperand(2);
7791+
EVT VT = Op.getValueType();
7792+
SDLoc DL(Op);
7793+
7794+
assert(VT.isVector() && "Scalar fma lowering should be handled by patterns");
7795+
7796+
// Bail early if we're definitely not looking to merge FNEGs into the FMA.
7797+
if (VT != MVT::v8f16 && VT != MVT::v4f32 && VT != MVT::v2f64)
7798+
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
7799+
7800+
if (OpC.getOpcode() != ISD::FNEG)
7801+
return useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())
7802+
? LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED)
7803+
: Op; // Fallback to NEON lowering.
7804+
7805+
// Convert FMA/FNEG nodes to SVE to enable the following patterns:
7806+
// fma(a, b, neg(c)) -> fnmls(a, b, c)
7807+
// fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
7808+
// fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
7809+
SDValue Pg = getPredicateForVector(DAG, DL, VT);
7810+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
7811+
7812+
auto ConvertToScalableFnegMt = [&](SDValue Op) {
7813+
if (Op.getOpcode() == ISD::FNEG)
7814+
Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
7815+
return convertToScalableVector(DAG, ContainerVT, Op);
7816+
};
7817+
7818+
OpA = ConvertToScalableFnegMt(OpA);
7819+
OpB = ConvertToScalableFnegMt(OpB);
7820+
OpC = ConvertToScalableFnegMt(OpC);
7821+
7822+
SDValue ScalableRes =
7823+
DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
7824+
return convertFromScalableVector(DAG, VT, ScalableRes);
7825+
}
7826+
77847827
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77857828
SelectionDAG &DAG) const {
77867829
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7857,7 +7900,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
78577900
case ISD::FMUL:
78587901
return LowerFMUL(Op, DAG);
78597902
case ISD::FMA:
7860-
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
7903+
return LowerFMA(Op, DAG);
78617904
case ISD::FDIV:
78627905
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
78637906
case ISD::FNEG:

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ class AArch64TargetLowering : public TargetLowering {
615615
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
616616
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
617617
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
618+
SDValue LowerFMA(SDValue Op, SelectionDAG &DAG) const;
618619

619620
SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
620621
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
462462
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
463463
[(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
464464
(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
465+
(AArch64fma_p node:$pg, node:$zn, (AArch64fneg_mt node:$pg, node:$zm, (undef)), (AArch64fneg_mt node:$pg, node:$za, (undef))),
465466
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
466467

467468
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),

0 commit comments

Comments
 (0)