@@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
51365136 if (!unrollExpandedOp ())
51375137 Res = WidenVecRes_UnaryOpWithTwoResults (N, ResNo);
51385138 break ;
5139+ case ISD::PARTIAL_REDUCE_UMLA:
5140+ case ISD::PARTIAL_REDUCE_SMLA:
5141+ Res = WidenVecRes_PARTIAL_REDUCE_MLA (N);
5142+ break ;
51395143 }
51405144 }
51415145
@@ -6995,6 +6999,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
69956999 return DAG.getBuildVector (WidenVT, dl, Scalars);
69967000}
69977001
7002+ // Widening the result of a partial reductions is implemented by
7003+ // accumulating into a wider (zero-padded) vector, then incrementally
7004+ // reducing that (extract half vector and add) until it fits
7005+ // the original type.
7006+ SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA (SDNode *N) {
7007+ SDLoc DL (N);
7008+ EVT VT = N->getValueType (0 );
7009+ EVT WideAccVT = TLI.getTypeToTransformTo (*DAG.getContext (),
7010+ N->getOperand (0 ).getValueType ());
7011+ SDValue Zero = DAG.getConstant (0 , DL, WideAccVT);
7012+ SDValue MulOp1 = N->getOperand (1 );
7013+ SDValue MulOp2 = N->getOperand (2 );
7014+ SDValue Acc = DAG.getInsertSubvector (DL, Zero, N->getOperand (0 ), 0 );
7015+ SDValue WidenedRes =
7016+ DAG.getNode (N->getOpcode (), DL, WideAccVT, Acc, MulOp1, MulOp2);
7017+ while (ElementCount::isKnownLT (
7018+ VT.getVectorElementCount (),
7019+ WidenedRes.getValueType ().getVectorElementCount ())) {
7020+ EVT HalfVT =
7021+ WidenedRes.getValueType ().getHalfNumVectorElementsVT (*DAG.getContext ());
7022+ SDValue Lo = DAG.getExtractSubvector (DL, HalfVT, WidenedRes, 0 );
7023+ SDValue Hi = DAG.getExtractSubvector (DL, HalfVT, WidenedRes,
7024+ HalfVT.getVectorMinNumElements ());
7025+ WidenedRes = DAG.getNode (ISD::ADD, DL, HalfVT, Lo, Hi);
7026+ }
7027+ return DAG.getInsertSubvector (DL, Zero, WidenedRes, 0 );
7028+ }
7029+
69987030// ===----------------------------------------------------------------------===//
69997031// Widen Vector Operand
70007032// ===----------------------------------------------------------------------===//
@@ -7127,6 +7159,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
71277159 case ISD::VP_REDUCE_FMINIMUM:
71287160 Res = WidenVecOp_VP_REDUCE (N);
71297161 break ;
7162+ case ISD::PARTIAL_REDUCE_UMLA:
7163+ case ISD::PARTIAL_REDUCE_SMLA:
7164+ Res = WidenVecOp_PARTIAL_REDUCE_MLA (N);
7165+ break ;
71307166 case ISD::VP_CTTZ_ELTS:
71317167 case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
71327168 Res = WidenVecOp_VP_CttzElements (N);
@@ -8026,6 +8062,21 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
80268062 {Source, Mask, N->getOperand (2 )}, N->getFlags ());
80278063}
80288064
8065+ SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA (SDNode *N) {
8066+ // Widening of multiplicant operands only. The result and accumulator
8067+ // should already be legal types.
8068+ SDLoc DL (N);
8069+ EVT WideOpVT = TLI.getTypeToTransformTo (*DAG.getContext (),
8070+ N->getOperand (1 ).getValueType ());
8071+ SDValue Acc = N->getOperand (0 );
8072+ SDValue WidenedOp1 = DAG.getInsertSubvector (
8073+ DL, DAG.getConstant (0 , DL, WideOpVT), N->getOperand (1 ), 0 );
8074+ SDValue WidenedOp2 = DAG.getInsertSubvector (
8075+ DL, DAG.getConstant (0 , DL, WideOpVT), N->getOperand (2 ), 0 );
8076+ return DAG.getNode (N->getOpcode (), DL, Acc.getValueType (), Acc, WidenedOp1,
8077+ WidenedOp2);
8078+ }
8079+
80298080// ===----------------------------------------------------------------------===//
80308081// Vector Widening Utilities
80318082// ===----------------------------------------------------------------------===//
0 commit comments