@@ -1992,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
1992
1992
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
1993
1993
case ISD::PARTIAL_REDUCE_SMLA:
1994
1994
case ISD::PARTIAL_REDUCE_UMLA:
1995
+ case ISD::PARTIAL_REDUCE_SUMLA:
1995
1996
return visitPARTIAL_REDUCE_MLA(N);
1996
1997
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
1997
1998
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -12737,26 +12738,27 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12737
12738
SDValue LHSExtOp = LHS->getOperand(0);
12738
12739
EVT LHSExtOpVT = LHSExtOp.getValueType();
12739
12740
12740
- bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12741
- unsigned NewOpcode =
12742
- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12743
-
12744
- // Only perform these combines if the target supports folding
12745
- // the extends into the operation.
12746
- if (!TLI.isPartialReduceMLALegalOrCustom(
12747
- NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12748
- TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12749
- return SDValue();
12750
-
12751
12741
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12752
12742
// -> partial_reduce_*mla(acc, x, C)
12753
12743
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
12744
+ // TODO: Make use of partial_reduce_sumla here
12754
12745
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
12755
12746
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
12756
12747
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
12757
12748
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
12758
12749
return SDValue();
12759
12750
12751
+ unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12752
+ ? ISD::PARTIAL_REDUCE_SMLA
12753
+ : ISD::PARTIAL_REDUCE_UMLA;
12754
+
12755
+ // Only perform these combines if the target supports folding
12756
+ // the extends into the operation.
12757
+ if (!TLI.isPartialReduceMLALegalOrCustom(
12758
+ NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12759
+ TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12760
+ return SDValue();
12761
+
12760
12762
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12761
12763
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
12762
12764
}
@@ -12766,26 +12768,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12766
12768
return SDValue();
12767
12769
12768
12770
SDValue RHSExtOp = RHS->getOperand(0);
12769
- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
12771
+ if (LHSExtOpVT != RHSExtOp.getValueType())
12772
+ return SDValue();
12773
+
12774
+ unsigned NewOpc;
12775
+ if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12776
+ NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12777
+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12778
+ NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12779
+ else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12780
+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12781
+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
12782
+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12783
+ std::swap(LHSExtOp, RHSExtOp);
12784
+ } else
12770
12785
return SDValue();
12771
-
12772
- // For a 2-stage extend the signedness of both of the extends must be the
12773
- // same. This is so the node can be folded into only a signed or unsigned
12774
- // node.
12775
- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12786
+ // For a 2-stage extend the signedness of both of the extends must match.
12787
+ // If the mul has the same type, there is no outer extend, and thus we
12788
+ // can simply use the inner extends to pick the result node.
12789
+ // TODO: extend to handle nonneg zext as sext
12776
12790
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12777
- if (ExtIsSigned != NodeIsSigned &&
12778
- Op1.getValueType().getVectorElementType() != AccElemVT )
12791
+ if (Op1.getValueType().getVectorElementType() != AccElemVT &&
12792
+ NewOpc != N->getOpcode() )
12779
12793
return SDValue();
12780
12794
12781
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12782
- RHSExtOp);
12795
+ // Only perform these combines if the target supports folding
12796
+ // the extends into the operation.
12797
+ if (!TLI.isPartialReduceMLALegalOrCustom(
12798
+ NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12799
+ TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12800
+ return SDValue();
12801
+
12802
+ return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
12783
12803
}
12784
12804
12785
12805
// partial.reduce.umla(acc, zext(op), splat(1))
12786
12806
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
12787
12807
// partial.reduce.smla(acc, sext(op), splat(1))
12788
12808
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12809
+ // partial.reduce.sumla(acc, sext(op), splat(1))
12810
+ // -> partial.reduce.smla(acc, op, splat(trunc(1)))
12789
12811
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
12790
12812
SDLoc DL(N);
12791
12813
SDValue Acc = N->getOperand(0);
@@ -12802,7 +12824,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
12802
12824
return SDValue();
12803
12825
12804
12826
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12805
- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ;
12827
+ bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA ;
12806
12828
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12807
12829
if (Op1IsSigned != NodeIsSigned &&
12808
12830
Op1.getValueType().getVectorElementType() != AccElemVT)
0 commit comments