@@ -1564,7 +1564,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15641564 ISD::MUL, ISD::SDIV, ISD::UDIV,
15651565 ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT,
15661566 ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE,
1567- ISD::VSELECT});
1567+ ISD::VSELECT, ISD::VECREDUCE_ADD });
15681568
15691569 if (Subtarget.hasVendorXTHeadMemPair())
15701570 setTargetDAGCombine({ISD::LOAD, ISD::STORE});
@@ -18144,25 +18144,38 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
1814418144// (iX ctpop (bitcast (vXi1 A)))
1814518145// ->
1814618146// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
18147+ // and
18148+ // (iN reduce.add (zext (vXi1 A to vXiN)))
18149+ // ->
18150+ // (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
1814718151// FIXME: It's complicated to match all the variations of this after type
1814818152// legalization so we only handle the pre-type legalization pattern, but that
1814918153// requires the fixed vector type to be legal.
18150- static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
18151- const RISCVSubtarget &Subtarget) {
18154+ static SDValue combineToVCPOP(SDNode *N, SelectionDAG &DAG,
18155+ const RISCVSubtarget &Subtarget) {
18156+ unsigned Opc = N->getOpcode();
18157+ assert((Opc == ISD::CTPOP || Opc == ISD::VECREDUCE_ADD) &&
18158+ "Unexpected opcode");
1815218159 EVT VT = N->getValueType(0);
1815318160 if (!VT.isScalarInteger())
1815418161 return SDValue();
1815518162
1815618163 SDValue Src = N->getOperand(0);
1815718164
18158- // Peek through zero_extend. It doesn't change the count.
18159- if (Src.getOpcode() == ISD::ZERO_EXTEND)
18160- Src = Src.getOperand(0);
18165+ if (Opc == ISD::CTPOP) {
18166+ // Peek through zero_extend. It doesn't change the count.
18167+ if (Src.getOpcode() == ISD::ZERO_EXTEND)
18168+ Src = Src.getOperand(0);
1816118169
18162- if (Src.getOpcode() != ISD::BITCAST)
18163- return SDValue();
18170+ if (Src.getOpcode() != ISD::BITCAST)
18171+ return SDValue();
18172+ Src = Src.getOperand(0);
18173+ } else if (Opc == ISD::VECREDUCE_ADD) {
18174+ if (Src.getOpcode() != ISD::ZERO_EXTEND)
18175+ return SDValue();
18176+ Src = Src.getOperand(0);
18177+ }
1816418178
18165- Src = Src.getOperand(0);
1816618179 EVT SrcEVT = Src.getValueType();
1816718180 if (!SrcEVT.isSimple())
1816818181 return SDValue();
@@ -18172,11 +18185,28 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
1817218185 if (!SrcMVT.isVector() || SrcMVT.getVectorElementType() != MVT::i1)
1817318186 return SDValue();
1817418187
18175- if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
18176- return SDValue();
18188+ // Check that destination type is large enough to hold result without
18189+ // overflow.
18190+ if (Opc == ISD::VECREDUCE_ADD) {
18191+ unsigned EltSize = SrcMVT.getScalarSizeInBits();
18192+ unsigned MinSize = SrcMVT.getSizeInBits().getKnownMinValue();
18193+ unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
18194+ unsigned MaxVLMAX = SrcMVT.isFixedLengthVector()
18195+ ? SrcMVT.getVectorNumElements()
18196+ : RISCVTargetLowering::computeVLMAX(
18197+ VectorBitsMax, EltSize, MinSize);
18198+ if (VT.getFixedSizeInBits() < Log2_32(MaxVLMAX) + 1)
18199+ return SDValue();
18200+ }
1817718201
18178- MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
18179- Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
18202+ MVT ContainerVT = SrcMVT;
18203+ if (SrcMVT.isFixedLengthVector()) {
18204+ if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
18205+ return SDValue();
18206+
18207+ ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
18208+ Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
18209+ }
1818018210
1818118211 SDLoc DL(N);
1818218212 auto [Mask, VL] = getDefaultVLOps(SrcMVT, ContainerVT, DL, DAG, Subtarget);
@@ -19258,7 +19288,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1925819288 return SDValue();
1925919289 }
1926019290 case ISD::CTPOP:
19261- if (SDValue V = combineScalarCTPOPToVCPOP(N, DAG, Subtarget))
19291+ case ISD::VECREDUCE_ADD:
19292+ if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
1926219293 return V;
1926319294 break;
1926419295 }
0 commit comments