Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
}

// Fixed-width (NEON) partial reductions, available only with the dot-product
// extension. The i8 -> i32 pairings are Legal; v16i8 -> v2i64 has no direct
// dot-product form and is Custom-lowered in two steps through v4i32.
// v8i16 -> v2i64 is intentionally not marked Legal: the NEON UDOT/SDOT
// instructions only accept 8-bit inputs, so that pairing keeps the default
// action.
if (EnablePartialReduceNodes && Subtarget->hasNEON() &&
    Subtarget->hasDotProd()) {
  setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
  setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
  setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
}

// Handle operations that are only available in non-streaming SVE mode.
if (Subtarget->isSVEAvailable()) {
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64,
Expand Down Expand Up @@ -27569,6 +27578,12 @@ void AArch64TargetLowering::ReplaceNodeResults(
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
// Partial-reduction multiply-accumulate nodes (unsigned and signed) share the
// same custom lowering. Results is left untouched when the lowering returns a
// null SDValue.
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
}
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
Expand Down Expand Up @@ -29518,37 +29533,58 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
}

/// Custom-lower a PARTIAL_REDUCE_(U|S)MLA node whose accumulator/input type
/// pairing is (nx)v2i64/(nx)v16i8. Such a node cannot be lowered directly to
/// a (u|s)dot, but the dot product instruction can still be used by
/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
///
/// For scalable vectors, the (U|S)ADDW(B|T) instructions are used when SVE2 or
/// streaming SVE is available. Otherwise the intermediate dot product is split
/// in half, each half is extended to the result type, and both halves are
/// added to the accumulator:
///   add(add(Acc, ext(lo(Dot))), ext(hi(Dot)))
///
/// \param Op  The PARTIAL_REDUCE_UMLA / PARTIAL_REDUCE_SMLA node; operand 0
///            is the accumulator, operands 1 and 2 are the multiplicands.
/// \param DAG The SelectionDAG used to build the replacement nodes.
/// \returns The lowered value, or a null SDValue when the required subtarget
///          features are not available.
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                               SelectionDAG &DAG) const {
  const bool Scalable = Op.getValueType().isScalableVector();
  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
    return SDValue();
  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
    return SDValue();

  SDLoc DL(Op);

  SDValue Acc = Op.getOperand(0);
  SDValue LHS = Op.getOperand(1);
  SDValue RHS = Op.getOperand(2);
  EVT ResultVT = Op.getValueType();
  assert(((Scalable && ResultVT == MVT::nxv2i64 &&
           LHS.getValueType() == MVT::nxv16i8) ||
          (!Scalable && ResultVT == MVT::v2i64 &&
           LHS.getValueType() == MVT::v16i8)) &&
         "Unexpected type pairing for custom PARTIAL_REDUCE_MLA lowering");

  // Step 1: (nx)v16i8 -> (nx)v4i32, accumulating into zero so that the
  // original accumulator is only added once, in the widening step below.
  EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
                                DAG.getConstant(0, DL, DotVT), LHS, RHS);

  bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;

  // Step 2 (scalable only): widen-and-add the bottom and top elements of the
  // dot product directly into the accumulator.
  if (Scalable &&
      (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
    SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
    return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
  }

  // Step 2 (generic): fold (nx)v4i32 into (nx)v2i64 by splitting the dot
  // product in half, extending each half according to the signedness of the
  // reduction, and adding both halves to the accumulator.
  auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
  if (IsUnsigned) {
    DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
    DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
  } else {
    DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
    DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
  }
  SDValue Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
  return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
}

SDValue
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
}

// Select partial_reduce_umla / partial_reduce_smla with 8-bit multiplicands
// and a 32-bit accumulator directly to the NEON UDOT / SDOT instructions.
// The first two patterns cover the 128-bit form (v16i8 inputs, v4i32
// accumulator), the last two the 64-bit form (v8i8 inputs, v2i32 accumulator).
let Predicates = [HasNEON, HasDotProd] in {
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
(v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
(v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
(v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
(v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
} // End HasNEON, HasDotProd

// ARMv8.6-A BFloat
let Predicates = [HasNEON, HasBF16] in {
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;
Expand Down
Loading
Loading