Skip to content

Commit d4e4360

Browse files
committed
Address review comments
1 parent 9aa8dd5 commit d4e4360

File tree

3 files changed

+143
-37
lines changed

3 files changed

+143
-37
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20450,46 +20450,41 @@ static SDValue performFMACombine(SDNode *N,
2045020450
TargetLowering::DAGCombinerInfo &DCI,
2045120451
const AArch64Subtarget *Subtarget) {
2045220452
SelectionDAG &DAG = DCI.DAG;
20453-
SDValue Op1 = N->getOperand(0);
20454-
SDValue Op2 = N->getOperand(1);
20455-
SDValue Op3 = N->getOperand(2);
20453+
SDValue OpA = N->getOperand(0);
20454+
SDValue OpB = N->getOperand(1);
20455+
SDValue OpC = N->getOperand(2);
2045620456
EVT VT = N->getValueType(0);
2045720457
SDLoc DL(N);
2045820458

2045920459
// fma(a, b, neg(c)) -> fnmls(a, b, c)
2046020460
// fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
2046120461
// fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
20462-
if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
20463-
(Subtarget->hasSVE() || Subtarget->hasSME())) {
20464-
if (Op3.getOpcode() == ISD::FNEG) {
20465-
unsigned int Opcode;
20466-
if (Op1.getOpcode() == ISD::FNEG) {
20467-
Op1 = Op1.getOperand(0);
20468-
Opcode = AArch64ISD::FNMLA_PRED;
20469-
} else if (Op2.getOpcode() == ISD::FNEG) {
20470-
Op2 = Op2.getOperand(0);
20471-
Opcode = AArch64ISD::FNMLA_PRED;
20472-
} else {
20473-
Opcode = AArch64ISD::FNMLS_PRED;
20474-
}
20475-
Op3 = Op3.getOperand(0);
20476-
auto Pg = getPredicateForVector(DAG, DL, VT);
20477-
if (VT.isFixedLengthVector()) {
20478-
assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
20479-
"Expected only legal fixed-width types");
20480-
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
20481-
Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
20482-
Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
20483-
Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
20484-
auto ScalableRes =
20485-
DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
20486-
return convertFromScalableVector(DAG, VT, ScalableRes);
20487-
}
20488-
return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
20489-
}
20462+
if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
20463+
!Subtarget->isSVEorStreamingSVEAvailable() ||
20464+
OpC.getOpcode() != ISD::FNEG) {
20465+
return SDValue();
20466+
}
20467+
unsigned int Opcode;
20468+
if (OpA.getOpcode() == ISD::FNEG) {
20469+
OpA = OpA.getOperand(0);
20470+
Opcode = AArch64ISD::FNMLA_PRED;
20471+
} else if (OpB.getOpcode() == ISD::FNEG) {
20472+
OpB = OpB.getOperand(0);
20473+
Opcode = AArch64ISD::FNMLA_PRED;
20474+
} else {
20475+
Opcode = AArch64ISD::FNMLS_PRED;
2049020476
}
20491-
20492-
return SDValue();
20477+
OpC = OpC.getOperand(0);
20478+
auto Pg = getPredicateForVector(DAG, DL, VT);
20479+
if (VT.isFixedLengthVector()) {
20480+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
20481+
OpA = convertToScalableVector(DAG, ContainerVT, OpA);
20482+
OpB = convertToScalableVector(DAG, ContainerVT, OpB);
20483+
OpC = convertToScalableVector(DAG, ContainerVT, OpC);
20484+
auto ScalableRes = DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
20485+
return convertFromScalableVector(DAG, VT, ScalableRes);
20486+
}
20487+
return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
2049320488
}
2049420489

2049520490
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,11 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
464464
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
465465
[(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
466466
(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
467-
(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
468467
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
469468

470469
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
471470
[(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
472-
(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
473-
(AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
471+
(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
474472

475473
def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
476474
(AArch64fsub_p node:$pg, node:$op2, node:$op1)>;

llvm/test/CodeGen/AArch64/sve-fmsub.ll

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2-
; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
2+
; RUN: llc -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
3+
; RUN: llc -mattr=+sme -force-streaming %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SME
4+
5+
target triple = "aarch64"
36

47
define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
58
; CHECK-LABEL: fmsub_nxv2f64:
@@ -274,3 +277,113 @@ entry:
274277
%0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
275278
ret <8 x half> %0
276279
}
280+
281+
; Illegal types
282+
283+
define <vscale x 3 x float> @fmsub_illegal_nxv3f32(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %c) {
284+
; CHECK-LABEL: fmsub_illegal_nxv3f32:
285+
; CHECK: // %bb.0: // %entry
286+
; CHECK-NEXT: ptrue p0.s
287+
; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
288+
; CHECK-NEXT: ret
289+
entry:
290+
%neg = fneg <vscale x 3 x float> %c
291+
%0 = tail call <vscale x 3 x float> @llvm.fmuladd(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %neg)
292+
ret <vscale x 3 x float> %0
293+
}
294+
295+
define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
296+
; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
297+
; CHECK-SVE: // %bb.0: // %entry
298+
; CHECK-SVE-NEXT: ptrue p0.d, vl1
299+
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
300+
; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
301+
; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
302+
; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
303+
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
304+
; CHECK-SVE-NEXT: ret
305+
;
306+
; CHECK-SME-LABEL: fmsub_illegal_v1f64:
307+
; CHECK-SME: // %bb.0: // %entry
308+
; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
309+
; CHECK-SME-NEXT: addvl sp, sp, #-1
310+
; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
311+
; CHECK-SME-NEXT: .cfi_offset w29, -16
312+
; CHECK-SME-NEXT: ptrue p0.d, vl1
313+
; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
314+
; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
315+
; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
316+
; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
317+
; CHECK-SME-NEXT: str z0, [sp]
318+
; CHECK-SME-NEXT: ldr d0, [sp]
319+
; CHECK-SME-NEXT: addvl sp, sp, #1
320+
; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
321+
; CHECK-SME-NEXT: ret
322+
entry:
323+
%neg = fneg <1 x double> %c
324+
%0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
325+
ret <1 x double> %0
326+
}
327+
328+
define <3 x float> @fmsub_flipped_illegal_v3f32(<3 x float> %c, <3 x float> %a, <3 x float> %b) {
329+
; CHECK-LABEL: fmsub_flipped_illegal_v3f32:
330+
; CHECK: // %bb.0: // %entry
331+
; CHECK-NEXT: ptrue p0.s, vl4
332+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
333+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
334+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
335+
; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
336+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
337+
; CHECK-NEXT: ret
338+
entry:
339+
%neg = fneg <3 x float> %c
340+
%0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %a, <3 x float> %b, <3 x float> %neg)
341+
ret <3 x float> %0
342+
}
343+
344+
define <vscale x 7 x half> @fnmsub_illegal_nxv7f16(<vscale x 7 x half> %a, <vscale x 7 x half> %b, <vscale x 7 x half> %c) {
345+
; CHECK-LABEL: fnmsub_illegal_nxv7f16:
346+
; CHECK: // %bb.0: // %entry
347+
; CHECK-NEXT: ptrue p0.h
348+
; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
349+
; CHECK-NEXT: ret
350+
entry:
351+
%neg = fneg <vscale x 7 x half> %a
352+
%neg1 = fneg <vscale x 7 x half> %c
353+
%0 = tail call <vscale x 7 x half> @llvm.fmuladd(<vscale x 7 x half> %neg, <vscale x 7 x half> %b, <vscale x 7 x half> %neg1)
354+
ret <vscale x 7 x half> %0
355+
}
356+
357+
define <3 x float> @fnmsub_illegal_v3f32(<3 x float> %a, <3 x float> %b, <3 x float> %c) {
358+
; CHECK-LABEL: fnmsub_illegal_v3f32:
359+
; CHECK: // %bb.0: // %entry
360+
; CHECK-NEXT: ptrue p0.s, vl4
361+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
362+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
363+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
364+
; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
365+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
366+
; CHECK-NEXT: ret
367+
entry:
368+
%neg = fneg <3 x float> %a
369+
%neg1 = fneg <3 x float> %c
370+
%0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %neg, <3 x float> %b, <3 x float> %neg1)
371+
ret <3 x float> %0
372+
}
373+
374+
define <7 x half> @fnmsub_flipped_illegal_v7f16(<7 x half> %c, <7 x half> %a, <7 x half> %b) {
375+
; CHECK-LABEL: fnmsub_flipped_illegal_v7f16:
376+
; CHECK: // %bb.0: // %entry
377+
; CHECK-NEXT: ptrue p0.h, vl8
378+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
379+
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
380+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
381+
; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
382+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
383+
; CHECK-NEXT: ret
384+
entry:
385+
%neg = fneg <7 x half> %a
386+
%neg1 = fneg <7 x half> %c
387+
%0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
388+
ret <7 x half> %0
389+
}

0 commit comments

Comments
 (0)