Skip to content
35 changes: 35 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,25 @@ class TargetLoweringBase {
getCondCodeAction(CC, VT) == Custom;
}

/// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
/// InputVT should be treated. Either it's legal, needs to be promoted to a
/// larger size, needs to be expanded to some other code sequence, or the
/// target has a custom expander for it.
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
"Table isn't big enough!");
return PartialReduceMLAActions[AccI][InputI];
}

/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
/// legal or custom for this target.
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
getPartialReduceMLAAction(AccVT, InputVT) == Custom;
}

/// If the action for this operation is to promote, this method returns the
/// ValueType to promote to.
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
Expand Down Expand Up @@ -2704,6 +2723,16 @@ class TargetLoweringBase {
setCondCodeAction(CCs, VT, Action);
}

/// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
/// type InputVT should be treated by the target. Either it's legal, needs to
/// be promoted to a larger size, needs to be expanded to some other code
/// sequence, or the target has a custom expander for it.
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
LegalizeAction Action) {
assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
}

/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
/// to trying a larger integer/fp until it can find one that works. If that
/// default is insufficient, this method can be used by the target to override
Expand Down Expand Up @@ -3650,6 +3679,12 @@ class TargetLoweringBase {
/// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];

/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
/// nodes, keep a LegalizeAction which indicates how instruction selection
/// should deal with this operation.
LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
[MVT::VALUETYPE_SIZE];

ValueTypeActionImpl ValueTypeActions;

private:
Expand Down
55 changes: 55 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ namespace {
SDValue visitMGATHER(SDNode *N);
SDValue visitMSCATTER(SDNode *N);
SDValue visitMHISTOGRAM(SDNode *N);
SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
SDValue visitVPGATHER(SDNode *N);
SDValue visitVPSCATTER(SDNode *N);
SDValue visitVP_STRIDED_LOAD(SDNode *N);
Expand Down Expand Up @@ -621,6 +622,8 @@ namespace {
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI);
SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);

SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
Expand Down Expand Up @@ -1972,6 +1975,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MSCATTER: return visitMSCATTER(N);
case ISD::MSTORE: return visitMSTORE(N);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
Expand Down Expand Up @@ -12497,6 +12503,55 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
// Splat(1)) into
// PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
// Splat(1)) into
// PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
SDLoc DL(N);

SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);

if (Op1->getOpcode() != ISD::MUL)
return SDValue();

APInt ConstantOne;
if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
!ConstantOne.isOne())
return SDValue();

SDValue ExtMulOpLHS = Op1->getOperand(0);
SDValue ExtMulOpRHS = Op1->getOperand(1);
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
return SDValue();

SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
EVT MulOpLHSVT = MulOpLHS.getValueType();
if (MulOpLHSVT != MulOpRHS.getValueType())
return SDValue();
// Only perform the DAG combine if there is custom lowering provided by the
// target
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), MulOpLHSVT))
return SDValue();

bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
if (LHSIsSigned != RHSIsSigned)
return SDValue();

unsigned NewOpcode =
LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, MulOpLHS,
MulOpRHS);
}

SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECTOR_COMPRESS:
case ISD::SCMP:
case ISD::UCMP:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
break;
case ISD::SMULFIX:
Expand Down Expand Up @@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
break;
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
Node->getOperand(1).getValueType());
break;

#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
case ISD::VPID: { \
Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::SET_FPENV, VT, Expand);
setOperationAction(ISD::RESET_FPENV, VT, Expand);

// PartialReduceMLA operations default to expand.
setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
Expand);
for (MVT InputVT : MVT::all_valuetypes())
setPartialReduceMLAAction(VT, InputVT, Expand);
}

// Most targets ignore the @llvm.prefetch intrinsic.
Expand Down