Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// FADDP custom lowering
for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
setOperationAction(ISD::FADD, VT, Custom);

if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
}

} else /* !isNeonAvailable */ {
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
Expand Down Expand Up @@ -27569,6 +27576,12 @@ void AArch64TargetLowering::ReplaceNodeResults(
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry just noticed this, but Is this code actually used? (normally this is only needed when the result type is not legal, but the input is).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It definitely was at one point in history, to support the v16i8 -> v4i64 cases. But as we're now handling that differently (by splitting the accumulator into v2i64) this code is never hit with the current test cases. Removed.

case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
Expand Down Expand Up @@ -29518,37 +29531,60 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
}

/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
/// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
/// however still make use of the dot product instruction by instead
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
/// If available, make use of the (U|S)ADDW(B|T) instructions, otherwise
/// the following pattern is emitted:
/// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0)), ext(EXTRACT_SUBVECTOR(N,
/// NTy/2))))
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
bool Scalable = Op.getValueType().isScalableVector();

assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
"SVE or StreamingSVE must be available when using scalable vectors.");
assert((Scalable || Subtarget->hasDotProd()) &&
"Dotprod must be available when targeting NEON dot product "
"instructions.");

SDLoc DL(Op);

SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);

SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
assert((Scalable && ResultVT == MVT::nxv2i64 &&
LHS.getValueType() == MVT::nxv16i8) ||
(!Scalable && ResultVT == MVT::v2i64 &&
LHS.getValueType() == MVT::v16i8));

EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
DAG.getConstant(0, DL, DotVT), LHS, RHS);

bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
if (Scalable &&
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
}

unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
// Fold (nx)v4i32 into (nx)v2i64
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
if (IsUnsigned) {
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
} else {
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
}
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
}

SDValue
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
}

let Predicates = [HasNEON, HasDotProd] in {
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
(v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
(v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
(v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
(v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
} // End HasNEON, HasDotProd

// ARMv8.6-A BFloat
let Predicates = [HasNEON, HasBF16] in {
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;
Expand Down
Loading
Loading