Skip to content

Commit 2879d78

Browse files
committed
Use custom lowering rather than DAG Combiner
1 parent 9892c86 commit 2879d78

File tree

5 files changed

+127
-86
lines changed

5 files changed

+127
-86
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 36 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,8 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11701170
setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
11711171
setTargetDAGCombine(ISD::CTPOP);
11721172

1173-
setTargetDAGCombine(ISD::FMA);
1174-
11751173
// In case of strict alignment, avoid an excessive number of byte wide stores.
11761174
MaxStoresPerMemsetOptSize = 8;
11771175
MaxStoresPerMemset =
@@ -1526,6 +1524,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15261524

15271525
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
15281526
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
1527+
1528+
for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) {
1529+
setOperationAction(ISD::FMA, VT, Custom);
1530+
}
15291531
}
15301532

15311533
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7732,6 +7734,37 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77327734
return FCVTNT(VT, BottomBF16, Pg, TopF32);
77337735
}
77347736

7737+
/// Custom lowering for ISD::FMA.
///
/// When the result type is a fixed-length vector and the addend (operand 2)
/// is an ISD::FNEG, the FMA and any FNEG feeding it are rebuilt directly on
/// the scalable container type so that the negations sit next to the
/// predicated FMA and these patterns become matchable:
///   fma(a, b, neg(c))      -> fnmls(a, b, c)
///   fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
///   fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
/// Every other case is forwarded unchanged to LowerToPredicatedOp with
/// AArch64ISD::FMA_PRED.
///
/// \param Op  The ISD::FMA node being lowered (operands: a, b, c).
/// \param DAG The DAG in which replacement nodes are created.
/// \returns The lowered value, of the same type as \p Op.
SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
  SDValue MulLHS = Op.getOperand(0);
  SDValue MulRHS = Op.getOperand(1);
  SDValue Addend = Op.getOperand(2);
  EVT ResVT = Op.getValueType();

  // Only a fixed-length FMA whose addend is negated is a candidate for
  // merging FNEGs; anything else takes the generic predicated lowering.
  const bool MayMergeNegations =
      ResVT.isFixedLengthVector() && Addend.getOpcode() == ISD::FNEG;
  if (!MayMergeNegations)
    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);

  SDLoc Loc(Op);
  SDValue Pred = getPredicateForVector(DAG, Loc, ResVT);
  EVT WideVT = getContainerForFixedLengthVector(DAG, ResVT);

  // Produces the container-typed form of one FMA operand.
  auto WidenOperand = [&](SDValue V) -> SDValue {
    if (V.getOpcode() != ISD::FNEG)
      return convertToScalableVector(DAG, WideVT, V);
    // Reuse `LowerToPredicatedOp` for the FNEG, then take operand 0 of what
    // it returns to drop the trailing `extract_subvector` and keep the value
    // on the container type.
    SDValue LoweredNeg =
        LowerToPredicatedOp(V, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
    return LoweredNeg.getOperand(0);
  };

  // Convert in operand order (a, b, c), matching the original traversal.
  MulLHS = WidenOperand(MulLHS);
  MulRHS = WidenOperand(MulRHS);
  Addend = WidenOperand(Addend);

  SDValue WideFMA = DAG.getNode(AArch64ISD::FMA_PRED, Loc, WideVT, Pred, MulLHS,
                                MulRHS, Addend);
  return convertFromScalableVector(DAG, ResVT, WideFMA);
}
7767+
77357768
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77367769
SelectionDAG &DAG) const {
77377770
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7808,7 +7841,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
78087841
case ISD::FMUL:
78097842
return LowerFMUL(Op, DAG);
78107843
case ISD::FMA:
7811-
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
7844+
return LowerFMA(Op, DAG);
78127845
case ISD::FDIV:
78137846
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
78147847
case ISD::FNEG:
@@ -20694,47 +20727,6 @@ static SDValue performFADDCombine(SDNode *N,
2069420727
return SDValue();
2069520728
}
2069620729

20697-
static SDValue performFMACombine(SDNode *N,
20698-
TargetLowering::DAGCombinerInfo &DCI,
20699-
const AArch64Subtarget *Subtarget) {
20700-
SelectionDAG &DAG = DCI.DAG;
20701-
SDValue OpA = N->getOperand(0);
20702-
SDValue OpB = N->getOperand(1);
20703-
SDValue OpC = N->getOperand(2);
20704-
EVT VT = N->getValueType(0);
20705-
SDLoc DL(N);
20706-
20707-
// Convert FMA/FNEG nodes to SVE to enable the following patterns:
20708-
// fma(a, b, neg(c)) -> fnmls(a, b, c)
20709-
// fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
20710-
// fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
20711-
if (!VT.isFixedLengthVector() ||
20712-
!DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
20713-
!Subtarget->isSVEorStreamingSVEAvailable() ||
20714-
OpC.getOpcode() != ISD::FNEG) {
20715-
return SDValue();
20716-
}
20717-
20718-
SDValue Pg = getPredicateForVector(DAG, DL, VT);
20719-
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
20720-
OpC =
20721-
DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
20722-
convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0)),
20723-
DAG.getUNDEF(ContainerVT));
20724-
20725-
OpA = OpA.getOpcode() == ISD::FNEG
20726-
? DAG.getNode(
20727-
AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
20728-
convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0)),
20729-
DAG.getUNDEF(ContainerVT))
20730-
: convertToScalableVector(DAG, ContainerVT, OpA);
20731-
20732-
OpB = convertToScalableVector(DAG, ContainerVT, OpB);
20733-
SDValue ScalableRes =
20734-
DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
20735-
return convertFromScalableVector(DAG, VT, ScalableRes);
20736-
}
20737-
2073820730
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
2073920731
switch (Opcode) {
2074020732
case ISD::STRICT_FADD:
@@ -28266,8 +28258,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2826628258
return performANDCombine(N, DCI);
2826728259
case ISD::FADD:
2826828260
return performFADDCombine(N, DCI);
28269-
case ISD::FMA:
28270-
return performFMACombine(N, DCI, Subtarget);
2827128261
case ISD::INTRINSIC_WO_CHAIN:
2827228262
return performIntrinsicCombine(N, DCI, Subtarget);
2827328263
case ISD::ANY_EXTEND:

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ class AArch64TargetLowering : public TargetLowering {
615615
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
616616
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
617617
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
618+
SDValue LowerFMA(SDValue Op, SelectionDAG &DAG) const;
618619

619620
SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
620621
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ define <4 x double> @simple_symmetric_muladd2(<4 x double> %a, <4 x double> %b)
77
; CHECK-LABEL: simple_symmetric_muladd2:
88
; CHECK: // %bb.0: // %entry
99
; CHECK-NEXT: mov x8, #-7378697629483820647 // =0x9999999999999999
10+
; CHECK-NEXT: ptrue p0.d, vl2
11+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
12+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
13+
; CHECK-NEXT: // kill: def $q3 killed $q3 def $z3
14+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
1015
; CHECK-NEXT: movk x8, #39322
1116
; CHECK-NEXT: movk x8, #16393, lsl #48
1217
; CHECK-NEXT: dup v4.2d, x8
13-
; CHECK-NEXT: fmla v2.2d, v4.2d, v0.2d
14-
; CHECK-NEXT: fmla v3.2d, v4.2d, v1.2d
15-
; CHECK-NEXT: mov v0.16b, v2.16b
16-
; CHECK-NEXT: mov v1.16b, v3.16b
18+
; CHECK-NEXT: fmad z0.d, p0/m, z4.d, z2.d
19+
; CHECK-NEXT: fmad z1.d, p0/m, z4.d, z3.d
20+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
21+
; CHECK-NEXT: // kill: def $q1 killed $q1 killed $z1
1722
; CHECK-NEXT: ret
1823
entry:
1924
%ext00 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -43,10 +48,11 @@ define <8 x double> @simple_symmetric_muladd4(<8 x double> %a, <8 x double> %b)
4348
; CHECK-NEXT: zip1 v17.2d, v5.2d, v7.2d
4449
; CHECK-NEXT: zip2 v5.2d, v5.2d, v7.2d
4550
; CHECK-NEXT: dup v6.2d, x8
46-
; CHECK-NEXT: fmla v3.2d, v6.2d, v16.2d
47-
; CHECK-NEXT: fmla v4.2d, v6.2d, v0.2d
48-
; CHECK-NEXT: fmla v17.2d, v6.2d, v2.2d
49-
; CHECK-NEXT: fmla v5.2d, v6.2d, v1.2d
51+
; CHECK-NEXT: ptrue p0.d, vl2
52+
; CHECK-NEXT: fmla z3.d, p0/m, z16.d, z6.d
53+
; CHECK-NEXT: fmla z4.d, p0/m, z0.d, z6.d
54+
; CHECK-NEXT: fmla z17.d, p0/m, z2.d, z6.d
55+
; CHECK-NEXT: fmla z5.d, p0/m, z1.d, z6.d
5056
; CHECK-NEXT: zip1 v0.2d, v3.2d, v4.2d
5157
; CHECK-NEXT: zip2 v2.2d, v3.2d, v4.2d
5258
; CHECK-NEXT: zip1 v1.2d, v17.2d, v5.2d

llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -620,8 +620,12 @@ define <4 x half> @fma_v4f16(<4 x half> %op1, <4 x half> %op2, <4 x half> %op3)
620620
define <8 x half> @fma_v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3) vscale_range(2,0) #0 {
621621
; CHECK-LABEL: fma_v8f16:
622622
; CHECK: // %bb.0:
623-
; CHECK-NEXT: fmla v2.8h, v1.8h, v0.8h
624-
; CHECK-NEXT: mov v0.16b, v2.16b
623+
; CHECK-NEXT: ptrue p0.h, vl8
624+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
625+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
626+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
627+
; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h
628+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
625629
; CHECK-NEXT: ret
626630
%res = call <8 x half> @llvm.fma.v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3)
627631
ret <8 x half> %res
@@ -730,8 +734,12 @@ define <2 x float> @fma_v2f32(<2 x float> %op1, <2 x float> %op2, <2 x float> %o
730734
define <4 x float> @fma_v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3) vscale_range(2,0) #0 {
731735
; CHECK-LABEL: fma_v4f32:
732736
; CHECK: // %bb.0:
733-
; CHECK-NEXT: fmla v2.4s, v1.4s, v0.4s
734-
; CHECK-NEXT: mov v0.16b, v2.16b
737+
; CHECK-NEXT: ptrue p0.s, vl4
738+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
739+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
740+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
741+
; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s
742+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
735743
; CHECK-NEXT: ret
736744
%res = call <4 x float> @llvm.fma.v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3)
737745
ret <4 x float> %res
@@ -839,8 +847,12 @@ define <1 x double> @fma_v1f64(<1 x double> %op1, <1 x double> %op2, <1 x double
839847
define <2 x double> @fma_v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3) vscale_range(2,0) #0 {
840848
; CHECK-LABEL: fma_v2f64:
841849
; CHECK: // %bb.0:
842-
; CHECK-NEXT: fmla v2.2d, v1.2d, v0.2d
843-
; CHECK-NEXT: mov v0.16b, v2.16b
850+
; CHECK-NEXT: ptrue p0.d, vl2
851+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
852+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
853+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
854+
; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d
855+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
844856
; CHECK-NEXT: ret
845857
%res = call <2 x double> @llvm.fma.v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3)
846858
ret <2 x double> %res

llvm/test/CodeGen/AArch64/sve-fmsub.ll

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,57 @@ entry:
267267
ret <8 x half> %0
268268
}
269269

270+
define <2 x double> @fnmsub_negated_b_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
271+
; CHECK-LABEL: fnmsub_negated_b_v2f64:
272+
; CHECK: // %bb.0: // %entry
273+
; CHECK-NEXT: ptrue p0.d, vl2
274+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
275+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
276+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
277+
; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
278+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
279+
; CHECK-NEXT: ret
280+
entry:
281+
%neg = fneg <2 x double> %b
282+
%neg1 = fneg <2 x double> %c
283+
%0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %neg, <2 x double> %neg1)
284+
ret <2 x double> %0
285+
}
286+
287+
define <4 x float> @fnmsub_negated_b_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
288+
; CHECK-LABEL: fnmsub_negated_b_v4f32:
289+
; CHECK: // %bb.0: // %entry
290+
; CHECK-NEXT: ptrue p0.s, vl4
291+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
292+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
293+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
294+
; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
295+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
296+
; CHECK-NEXT: ret
297+
entry:
298+
%neg = fneg <4 x float> %b
299+
%neg1 = fneg <4 x float> %c
300+
%0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %neg, <4 x float> %neg1)
301+
ret <4 x float> %0
302+
}
303+
304+
define <8 x half> @fnmsub_negated_b_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
305+
; CHECK-LABEL: fnmsub_negated_b_v8f16:
306+
; CHECK: // %bb.0: // %entry
307+
; CHECK-NEXT: ptrue p0.h, vl8
308+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
309+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
310+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
311+
; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
312+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
313+
; CHECK-NEXT: ret
314+
entry:
315+
%neg = fneg <8 x half> %b
316+
%neg1 = fneg <8 x half> %c
317+
%0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %neg, <8 x half> %neg1)
318+
ret <8 x half> %0
319+
}
320+
270321
define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
271322
; CHECK-LABEL: fnmsub_flipped_v2f64:
272323
; CHECK: // %bb.0: // %entry
@@ -333,32 +384,10 @@ entry:
333384
}
334385

335386
define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
336-
; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
337-
; CHECK-SVE: // %bb.0: // %entry
338-
; CHECK-SVE-NEXT: ptrue p0.d, vl1
339-
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
340-
; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
341-
; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
342-
; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
343-
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
344-
; CHECK-SVE-NEXT: ret
345-
;
346-
; CHECK-SME-LABEL: fmsub_illegal_v1f64:
347-
; CHECK-SME: // %bb.0: // %entry
348-
; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
349-
; CHECK-SME-NEXT: addvl sp, sp, #-1
350-
; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
351-
; CHECK-SME-NEXT: .cfi_offset w29, -16
352-
; CHECK-SME-NEXT: ptrue p0.d, vl1
353-
; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
354-
; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
355-
; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
356-
; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
357-
; CHECK-SME-NEXT: str z0, [sp]
358-
; CHECK-SME-NEXT: ldr d0, [sp]
359-
; CHECK-SME-NEXT: addvl sp, sp, #1
360-
; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
361-
; CHECK-SME-NEXT: ret
387+
; CHECK-LABEL: fmsub_illegal_v1f64:
388+
; CHECK: // %bb.0: // %entry
389+
; CHECK-NEXT: fnmsub d0, d0, d1, d2
390+
; CHECK-NEXT: ret
362391
entry:
363392
%neg = fneg <1 x double> %c
364393
%0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
@@ -427,3 +456,6 @@ entry:
427456
%0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
428457
ret <7 x half> %0
429458
}
459+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
460+
; CHECK-SME: {{.*}}
461+
; CHECK-SVE: {{.*}}

0 commit comments

Comments
 (0)