Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1484,8 +1484,9 @@ enum NodeType {
VECREDUCE_UMIN,

// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
// The partial reduction nodes sign or zero extend Input1 and Input2 to the
// element type of Accumulator before multiplying their results.
// The partial reduction nodes sign or zero extend Input1 and Input2
// (with the extension kind noted below) to the element type of
// Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
// The output is only expected to either be given to another partial reduction
Expand All @@ -1497,8 +1498,9 @@ enum NodeType {
// multiple of the number of elements in the Accumulator / output type.
// Input1 and Input2 must have an element type which is the same as or smaller
// than the element type of the Accumulator and output.
PARTIAL_REDUCE_SMLA,
PARTIAL_REDUCE_UMLA,
PARTIAL_REDUCE_SMLA, // sext, sext
PARTIAL_REDUCE_UMLA, // zext, zext
PARTIAL_REDUCE_SUMLA, // sext, zext

// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
Expand Down
53 changes: 37 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
Expand Down Expand Up @@ -12675,19 +12676,19 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;

// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
// TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();

unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
? ISD::PARTIAL_REDUCE_SMLA
: ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}
Expand All @@ -12697,26 +12698,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
return SDValue();

SDValue RHSExtOp = RHS->getOperand(0);
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
if (LHSExtOpVT != RHSExtOp.getValueType())
return SDValue();

// For a 2-stage extend the signedness of both of the extends must be the
// same. This is so the node can be folded into only a signed or unsigned
// node.
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
unsigned NewOpc = ISD::PARTIAL_REDUCE_SMLA;
// For a 2-stage extend the signedness of both of the extends must match.
// If the mul has the same type, there is no outer extend, and thus we
// can simply use the inner extends to pick the result node.
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (ExtIsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
RHSExtOp);
if (Op1.getValueType().getVectorElementType() != AccElemVT) {
// TODO: Split this into canonicalization rules
if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND &&
(N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ||
N->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA))
NewOpc = ISD::PARTIAL_REDUCE_SMLA;
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND &&
N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA)
NewOpc = ISD::PARTIAL_REDUCE_UMLA;
else
return SDValue();
} else {
// TODO: Add canonicalization rule
if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
NewOpc = ISD::PARTIAL_REDUCE_SMLA;
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
NewOpc = ISD::PARTIAL_REDUCE_UMLA;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on what you mean by this TODO? I'm not sure I follow why we'd want to handle a zext as a sext in this case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A zext nonneg is a zext for which the high bit is known to be zero, and thus is equivalent to a sext. We canonicalize such cases to zext nonneg. As such, handling zext nonneg would allow us to recognize more partial_reduce_smla cases which we'd currently miss. Note that for partial_reduce_sumla-enabled targets, this might not matter since we'd just choose an alternate instruction.

else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
else
// TODO: Handle the swapped sumla case here
return SDValue();
}
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}

// partial.reduce.umla(acc, zext(op), splat(1))
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
// partial.reduce.smla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
Expand All @@ -12738,7 +12759,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
return SDValue();

bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {

case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
break;

Expand Down Expand Up @@ -2090,6 +2091,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
break;
}
Expand Down Expand Up @@ -2876,12 +2878,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,

SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SmallVector<SDValue, 1> NewOps(N->ops());
if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
switch (N->getOpcode()) {
case ISD::PARTIAL_REDUCE_SMLA:
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
NewOps[2] = SExtPromotedInteger(N->getOperand(2));
} else {
break;
case ISD::PARTIAL_REDUCE_UMLA:
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
break;
case ISD::PARTIAL_REDUCE_SUMLA:
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
break;
default:
llvm_unreachable("unexpected opcode");
}
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
Node->getOperand(1).getValueType());
break;
Expand Down Expand Up @@ -1210,6 +1211,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
}
Expand Down Expand Up @@ -3454,6 +3455,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7967,7 +7967,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
return "partial_reduce_umla";
case ISD::PARTIAL_REDUCE_SMLA:
return "partial_reduce_smla";
case ISD::PARTIAL_REDUCE_SUMLA:
return "partial_reduce_sumla";

// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
Expand Down
22 changes: 16 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11887,13 +11887,23 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT ExtMulOpVT =
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
MulOpVT.getVectorElementCount());
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;

if (ExtMulOpVT != MulOpVT) {
MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
switch (N->getOpcode()) {
case ISD::PARTIAL_REDUCE_SMLA:
MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulRHS);
break;
case ISD::PARTIAL_REDUCE_UMLA:
MulLHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
break;
case ISD::PARTIAL_REDUCE_SUMLA:
MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
break;
default:
llvm_unreachable("unexpected opcode");
}
}
SDValue Input = MulLHS;
APInt ConstantOne;
Expand Down
18 changes: 15 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1874,8 +1874,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Custom);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change this interface (and internal table) to include the kind of extension being done? (signed/unsigned or mixed)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just throwing this out there, but perhaps change it to setReduceAction(Opcode, ResultType, OperandType) to make it easier to add more reductions in the future? This might also be a way to relax the result type requirements of the current VECREDUCE_*** nodes.

setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);

setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
}
Expand Down Expand Up @@ -7745,6 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerVECTOR_HISTOGRAM(Op, DAG);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
Expand Down Expand Up @@ -29532,13 +29533,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
// No support for sumla forms, let generic legalization handle them
if (Op->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA)
return SDValue();

SDLoc DL(Op);

SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
EVT OpVT = LHS.getValueType();

// These two are legal...
if ((ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv8i16) ||
(ResultVT == MVT::nxv4i32 && OpVT == MVT::nxv16i8))
return Op;

assert(ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv16i8);

SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
Expand Down
17 changes: 15 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8240,6 +8240,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerADJUST_TRAMPOLINE(Op, DAG);
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
Expand Down Expand Up @@ -8391,8 +8392,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue B = Op.getOperand(2);
assert(A.getSimpleValueType() == B.getSimpleValueType() &&
A.getSimpleValueType().getVectorElementType() == MVT::i8);
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
unsigned Opc;
switch (Op.getOpcode()) {
case ISD::PARTIAL_REDUCE_SMLA:
Opc = RISCVISD::VQDOT_VL;
break;
case ISD::PARTIAL_REDUCE_UMLA:
Opc = RISCVISD::VQDOTU_VL;
break;
case ISD::PARTIAL_REDUCE_SUMLA:
Opc = RISCVISD::VQDOTSU_VL;
break;
default:
llvm_unreachable("Unexpected opcode");
}
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
}
Expand Down
Loading