@@ -2041,9 +2041,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
20412041 return true;
20422042
20432043 EVT VT = EVT::getEVT(I->getType());
2044- return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
2045- VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
2046- VT != MVT::v2i32 && VT != MVT::v8i16;
2044+ auto Op1 = I->getOperand(1);
2045+ EVT Op1VT = EVT::getEVT(Op1->getType());
2046+ if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
2047+ (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
2048+ VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
2049+ return false;
2050+ return true;
20472051}
20482052
20492053bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21793,36 +21797,34 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
2179321797 Intrinsic::experimental_vector_partial_reduce_add &&
2179421798 "Expected a partial reduction node");
2179521799
21796- bool Scalable = N->getValueType(0).isScalableVector();
21797- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
21800+ if (!Subtarget->isSVEorStreamingSVEAvailable())
2179821801 return SDValue();
2179921802
2180021803 SDLoc DL(N);
2180121804
21802- auto Accumulator = N->getOperand(1);
21805+ auto Acc = N->getOperand(1);
2180321806 auto ExtInput = N->getOperand(2);
2180421807
21805- EVT AccumulatorType = Accumulator .getValueType();
21806- EVT AccumulatorElementType = AccumulatorType .getVectorElementType();
21808+ EVT AccVT = Acc .getValueType();
21809+ EVT AccElemVT = AccVT .getVectorElementType();
2180721810
21808- if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType )
21811+ if (ExtInput.getValueType().getVectorElementType() != AccElemVT )
2180921812 return SDValue();
2181021813
2181121814 unsigned ExtInputOpcode = ExtInput->getOpcode();
2181221815 if (!ISD::isExtOpcode(ExtInputOpcode))
2181321816 return SDValue();
2181421817
2181521818 auto Input = ExtInput->getOperand(0);
21816- EVT InputType = Input.getValueType();
21819+ EVT InputVT = Input.getValueType();
2181721820
2181821821 // To do this transformation, output element size needs to be double input
2181921822 // element size, and output number of elements needs to be half the input
2182021823 // number of elements
21821- if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
21822- AccumulatorElementType.getSizeInBits()) ||
21823- !(AccumulatorType.getVectorElementCount() * 2 ==
21824- InputType.getVectorElementCount()) ||
21825- !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
21824+ if (InputVT.getVectorElementType().getSizeInBits() * 2 !=
21825+ AccElemVT.getSizeInBits() ||
21826+ AccVT.getVectorElementCount() * 2 != InputVT.getVectorElementCount() ||
21827+ AccVT.isScalableVector() != InputVT.isScalableVector())
2182621828 return SDValue();
2182721829
2182821830 bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
@@ -21831,13 +21833,12 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
2183121833 auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
2183221834 : Intrinsic::aarch64_sve_uaddwt;
2183321835
21834- auto BottomID =
21835- DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
21836- auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
21837- BottomID, Accumulator, Input);
21838- auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
21839- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
21840- BottomNode, Input);
21836+ auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
21837+ auto BottomNode =
21838+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
21839+ auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
21840+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
21841+ Input);
2184121842}
2184221843
2184321844static SDValue performIntrinsicCombine(SDNode *N,
0 commit comments