@@ -1992,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
19921992 case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
19931993 case ISD::PARTIAL_REDUCE_SMLA:
19941994 case ISD::PARTIAL_REDUCE_UMLA:
1995+ case ISD::PARTIAL_REDUCE_SUMLA:
19951996 return visitPARTIAL_REDUCE_MLA(N);
19961997 case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
19971998 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -12737,26 +12738,27 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1273712738 SDValue LHSExtOp = LHS->getOperand(0);
1273812739 EVT LHSExtOpVT = LHSExtOp.getValueType();
1273912740
12740- bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12741- unsigned NewOpcode =
12742- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12743-
12744- // Only perform these combines if the target supports folding
12745- // the extends into the operation.
12746- if (!TLI.isPartialReduceMLALegalOrCustom(
12747- NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12748- TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12749- return SDValue();
12750-
1275112741 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1275212742 // -> partial_reduce_*mla(acc, x, C)
1275312743 if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
12744+ // TODO: Make use of partial_reduce_sumla here
1275412745 APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
1275512746 unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
1275612747 if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
1275712748 (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
1275812749 return SDValue();
1275912750
12751+ unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12752+ ? ISD::PARTIAL_REDUCE_SMLA
12753+ : ISD::PARTIAL_REDUCE_UMLA;
12754+
12755+ // Only perform these combines if the target supports folding
12756+ // the extends into the operation.
12757+ if (!TLI.isPartialReduceMLALegalOrCustom(
12758+ NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12759+ TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12760+ return SDValue();
12761+
1276012762 return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
1276112763 DAG.getConstant(CTrunc, DL, LHSExtOpVT));
1276212764 }
@@ -12766,26 +12768,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1276612768 return SDValue();
1276712769
1276812770 SDValue RHSExtOp = RHS->getOperand(0);
12769- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
12771+ if (LHSExtOpVT != RHSExtOp.getValueType())
12772+ return SDValue();
12773+
12774+ unsigned NewOpc;
12775+ if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12776+ NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12777+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12778+ NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12779+ else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12780+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12781+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
12782+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12783+ std::swap(LHSExtOp, RHSExtOp);
12784+ } else
1277012785 return SDValue();
12771-
12772- // For a 2-stage extend the signedness of both of the extends must be the
12773- // same. This is so the node can be folded into only a signed or unsigned
12774- // node.
12775- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12786+ // For a 2-stage extend the signedness of both of the extends must match
12787+ // If the mul has the same type, there is no outer extend, and thus we
12788+ // can simply use the inner extends to pick the result node.
12789+ // TODO: extend to handle nonneg zext as sext
1277612790 EVT AccElemVT = Acc.getValueType().getVectorElementType();
12777- if (ExtIsSigned != NodeIsSigned &&
12778- Op1.getValueType().getVectorElementType() != AccElemVT )
12791+ if (Op1.getValueType().getVectorElementType() != AccElemVT &&
12792+ NewOpc != N->getOpcode() )
1277912793 return SDValue();
1278012794
12781- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12782- RHSExtOp);
12795+ // Only perform these combines if the target supports folding
12796+ // the extends into the operation.
12797+ if (!TLI.isPartialReduceMLALegalOrCustom(
12798+ NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12799+ TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12800+ return SDValue();
12801+
12802+ return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
1278312803}
1278412804
1278512805// partial.reduce.umla(acc, zext(op), splat(1))
1278612806// -> partial.reduce.umla(acc, op, splat(trunc(1)))
1278712807// partial.reduce.smla(acc, sext(op), splat(1))
1278812808// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12809+ // partial.reduce.sumla(acc, sext(op), splat(1))
12810+ // -> partial.reduce.smla(acc, op, splat(trunc(1)))
1278912811SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1279012812 SDLoc DL(N);
1279112813 SDValue Acc = N->getOperand(0);
@@ -12802,7 +12824,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1280212824 return SDValue();
1280312825
1280412826 bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12805- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ;
12827+ bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA ;
1280612828 EVT AccElemVT = Acc.getValueType().getVectorElementType();
1280712829 if (Op1IsSigned != NodeIsSigned &&
1280812830 Op1.getValueType().getVectorElementType() != AccElemVT)
0 commit comments