Skip to content

Commit 6e40a1e

Browse files
committed
Fix up NEON 8-to-64 cases
1 parent 3a1ce17 commit 6e40a1e

File tree

2 files changed

+345
-158
lines changed

2 files changed

+345
-158
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18721872
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
18731873
}
18741874

1875+
if (EnablePartialReduceNodes && Subtarget->hasNEON() &&
1876+
Subtarget->hasDotProd()) {
1877+
setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Legal);
1878+
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1879+
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1880+
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1881+
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1882+
}
1883+
18751884
// Handle operations that are only available in non-streaming SVE mode.
18761885
if (Subtarget->isSVEAvailable()) {
18771886
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64,
@@ -7743,8 +7752,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77437752
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
77447753
return LowerVECTOR_HISTOGRAM(Op, DAG);
77457754
case ISD::PARTIAL_REDUCE_SMLA:
7746-
case ISD::PARTIAL_REDUCE_UMLA:
7747-
case ISD::PARTIAL_REDUCE_SMLA: {
7755+
case ISD::PARTIAL_REDUCE_UMLA: {
77487756
if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG))
77497757
return Result;
77507758
return expandPartialReduceMLA(Op.getNode(), DAG);
@@ -27575,8 +27583,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2757527583
return;
2757627584
case ISD::PARTIAL_REDUCE_UMLA:
2757727585
case ISD::PARTIAL_REDUCE_SMLA: {
27578-
SDValue Res;
27579-
if (Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
27586+
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
2758027587
Results.push_back(Res);
2758127588
else
2758227589
Results.push_back(expandPartialReduceMLA(N, DAG));
@@ -29531,9 +29538,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2953129538
}
2953229539

2953329540
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29534-
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
29541+
/// of nxv2i64/nxv16i8 or v2i64/v16i8, we cannot directly lower it to a (u|s)dot. We can
2953529542
/// however still make use of the dot product instruction by instead
29536-
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
29543+
/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
2953729544
SDValue
2953829545
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2953929546
SelectionDAG &DAG) const {
@@ -29568,12 +29575,27 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2956829575
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2956929576
}
2957029577

29571-
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29572-
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29573-
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29574-
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29575-
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29576-
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29578+
if (Scalable) {
29579+
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29580+
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29581+
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29582+
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29583+
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29584+
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29585+
}
29586+
29587+
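// Fold the v4i32 dot product into the v2i64 accumulator: split it into two
29588+
// halves, zero- or sign-extend each half to v2i64, then add both to Acc.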
29589+
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29590+
if (IsUnsigned) {
29591+
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
29592+
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
29593+
} else {
29594+
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
29595+
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
29596+
}
29597+
auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo);
29598+
return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi);
2957729599
}
2957829600

2957929601
SDValue

0 commit comments

Comments
 (0)