Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 55 additions & 60 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21738,74 +21738,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,

SDLoc DL(N);

// The narrower of the two operands. Used as the accumulator
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
if (MulOp->getOpcode() != ISD::MUL)
SDValue Op2 = N->getOperand(2);
if (Op2->getOpcode() != ISD::MUL ||
!ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
!ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
return SDValue();

auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);

if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
!ISD::isExtOpcode(ExtB->getOpcode()))
return SDValue();
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
SDValue Acc = N->getOperand(1);
SDValue Mul = N->getOperand(2);
SDValue ExtMulOpLHS = Mul->getOperand(0);
SDValue ExtMulOpRHS = Mul->getOperand(1);

auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
return SDValue();

EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();
EVT ReducedVT = N->getValueType(0);
EVT MulSrcVT = MulOpLHS.getValueType();

// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
!(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
!(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
!(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
!(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
!(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();

bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
if (AIsSigned != BIsSigned) {
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();

bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();

Opcode = AArch64ISD::USDOT;
// USDOT expects the signed operand to be last
if (!BIsSigned)
std::swap(A, B);
} else if (AIsSigned)
if (!MulOpRHSIsSigned)
std::swap(MulOpLHS, MulOpRHS);
} else if (MulOpLHSIsSigned)
Opcode = AArch64ISD::SDOT;
else
Opcode = AArch64ISD::UDOT;

// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
EVT ReducedTypeI32 =
(ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
EVT ReducedVTI32 =
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;

auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
DAG.getConstant(0, DL, ReducedTypeI32), A, B);
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
Extended);
SDValue DotI32 =
DAG.getNode(Opcode, DL, ReducedVTI32,
DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
}

return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}

SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
Expand All @@ -21822,32 +21820,29 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,

SDLoc DL(N);

auto Acc = N->getOperand(1);
auto ExtInput = N->getOperand(2);

EVT AccVT = Acc.getValueType();
EVT AccElemVT = AccVT.getVectorElementType();

if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
return SDValue();

unsigned ExtInputOpcode = ExtInput->getOpcode();
if (!ISD::isExtOpcode(ExtInputOpcode))
SDValue Acc = N->getOperand(1);
SDValue Ext = N->getOperand(2);
EVT AccVT = Acc.getValueType();
EVT ExtVT = Ext.getValueType();
if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
return SDValue();

auto Input = ExtInput->getOperand(0);
EVT InputVT = Input.getValueType();
SDValue ExtOp = Ext->getOperand(0);
EVT ExtOpVT = ExtOp.getValueType();

if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();

bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
unsigned BottomOpcode =
ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
}

static SDValue performIntrinsicCombine(SDNode *N,
Expand All @@ -21859,9 +21854,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
default:
break;
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
Expand Down
Loading