@@ -22515,138 +22515,6 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
2251522515 return SDValue();
2251622516}
2251722517
22518- SDValue tryLowerPartialReductionToDot(SDNode *N,
22519- const AArch64Subtarget *Subtarget,
22520- SelectionDAG &DAG) {
22521-
22522- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
22523- getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
22524- "Expected a partial reduction node");
22525-
22526- bool Scalable = N->getValueType(0).isScalableVector();
22527- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22528- return SDValue();
22529- if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22530- return SDValue();
22531-
22532- SDLoc DL(N);
22533-
22534- SDValue Op2 = N->getOperand(2);
22535- unsigned Op2Opcode = Op2->getOpcode();
22536- SDValue MulOpLHS, MulOpRHS;
22537- bool MulOpLHSIsSigned, MulOpRHSIsSigned;
22538- if (ISD::isExtOpcode(Op2Opcode)) {
22539- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
22540- MulOpLHS = Op2->getOperand(0);
22541- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
22542- } else if (Op2Opcode == ISD::MUL) {
22543- SDValue ExtMulOpLHS = Op2->getOperand(0);
22544- SDValue ExtMulOpRHS = Op2->getOperand(1);
22545-
22546- unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
22547- unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
22548- if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
22549- !ISD::isExtOpcode(ExtMulOpRHSOpcode))
22550- return SDValue();
22551-
22552- MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
22553- MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
22554-
22555- MulOpLHS = ExtMulOpLHS->getOperand(0);
22556- MulOpRHS = ExtMulOpRHS->getOperand(0);
22557-
22558- if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
22559- return SDValue();
22560- } else
22561- return SDValue();
22562-
22563- SDValue Acc = N->getOperand(1);
22564- EVT ReducedVT = N->getValueType(0);
22565- EVT MulSrcVT = MulOpLHS.getValueType();
22566-
22567- // Dot products operate on chunks of four elements so there must be four times
22568- // as many elements in the wide type
22569- if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22570- !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22571- !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22572- !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22573- !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22574- !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22575- return SDValue();
22576-
22577- // If the extensions are mixed, we should lower it to a usdot instead
22578- unsigned Opcode = 0;
22579- if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
22580- if (!Subtarget->hasMatMulInt8())
22581- return SDValue();
22582-
22583- bool Scalable = N->getValueType(0).isScalableVT();
22584- // There's no nxv2i64 version of usdot
22585- if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
22586- return SDValue();
22587-
22588- Opcode = AArch64ISD::USDOT;
22589- // USDOT expects the signed operand to be last
22590- if (!MulOpRHSIsSigned)
22591- std::swap(MulOpLHS, MulOpRHS);
22592- } else
22593- Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22594-
22595- // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
22596- // product followed by a zero / sign extension
22597- if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22598- (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22599- EVT ReducedVTI32 =
22600- (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22601-
22602- SDValue DotI32 =
22603- DAG.getNode(Opcode, DL, ReducedVTI32,
22604- DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22605- SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22606- return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
22607- }
22608-
22609- return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
22610- }
22611-
22612- SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
22613- const AArch64Subtarget *Subtarget,
22614- SelectionDAG &DAG) {
22615-
22616- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
22617- getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
22618- "Expected a partial reduction node");
22619-
22620- if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22621- return SDValue();
22622-
22623- SDLoc DL(N);
22624-
22625- if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
22626- return SDValue();
22627- SDValue Acc = N->getOperand(1);
22628- SDValue Ext = N->getOperand(2);
22629- EVT AccVT = Acc.getValueType();
22630- EVT ExtVT = Ext.getValueType();
22631- if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
22632- return SDValue();
22633-
22634- SDValue ExtOp = Ext->getOperand(0);
22635- EVT ExtOpVT = ExtOp.getValueType();
22636-
22637- if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22638- !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22639- !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22640- return SDValue();
22641-
22642- bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
22643- unsigned BottomOpcode =
22644- ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22645- unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22646- SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
22647- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
22648- }
22649-
2265022518static SDValue combineSVEBitSel(unsigned IID, SDNode *N, SelectionDAG &DAG) {
2265122519 SDLoc DL(N);
2265222520 EVT VT = N->getValueType(0);
@@ -22679,17 +22547,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
2267922547 switch (IID) {
2268022548 default:
2268122549 break;
22682- case Intrinsic::vector_partial_reduce_add: {
22683- if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22684- return Dot;
22685- if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22686- return WideAdd;
22687- SDLoc DL(N);
22688- SDValue Input = N->getOperand(2);
22689- return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, N->getValueType(0),
22690- N->getOperand(1), Input,
22691- DAG.getConstant(1, DL, Input.getValueType()));
22692- }
2269322550 case Intrinsic::aarch64_neon_vcvtfxs2fp:
2269422551 case Intrinsic::aarch64_neon_vcvtfxu2fp:
2269522552 return tryCombineFixedPointConvert(N, DCI, DAG);
0 commit comments