@@ -22514,138 +22514,6 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
2251422514 return SDValue();
2251522515}
2251622516
22517- SDValue tryLowerPartialReductionToDot(SDNode *N,
22518- const AArch64Subtarget *Subtarget,
22519- SelectionDAG &DAG) {
22520-
22521- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
22522- getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
22523- "Expected a partial reduction node");
22524-
22525- bool Scalable = N->getValueType(0).isScalableVector();
22526- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22527- return SDValue();
22528- if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22529- return SDValue();
22530-
22531- SDLoc DL(N);
22532-
22533- SDValue Op2 = N->getOperand(2);
22534- unsigned Op2Opcode = Op2->getOpcode();
22535- SDValue MulOpLHS, MulOpRHS;
22536- bool MulOpLHSIsSigned, MulOpRHSIsSigned;
22537- if (ISD::isExtOpcode(Op2Opcode)) {
22538- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
22539- MulOpLHS = Op2->getOperand(0);
22540- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
22541- } else if (Op2Opcode == ISD::MUL) {
22542- SDValue ExtMulOpLHS = Op2->getOperand(0);
22543- SDValue ExtMulOpRHS = Op2->getOperand(1);
22544-
22545- unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
22546- unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
22547- if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
22548- !ISD::isExtOpcode(ExtMulOpRHSOpcode))
22549- return SDValue();
22550-
22551- MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
22552- MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
22553-
22554- MulOpLHS = ExtMulOpLHS->getOperand(0);
22555- MulOpRHS = ExtMulOpRHS->getOperand(0);
22556-
22557- if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
22558- return SDValue();
22559- } else
22560- return SDValue();
22561-
22562- SDValue Acc = N->getOperand(1);
22563- EVT ReducedVT = N->getValueType(0);
22564- EVT MulSrcVT = MulOpLHS.getValueType();
22565-
22566- // Dot products operate on chunks of four elements so there must be four times
22567- // as many elements in the wide type
22568- if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22569- !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22570- !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22571- !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22572- !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22573- !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22574- return SDValue();
22575-
22576- // If the extensions are mixed, we should lower it to a usdot instead
22577- unsigned Opcode = 0;
22578- if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
22579- if (!Subtarget->hasMatMulInt8())
22580- return SDValue();
22581-
22582- bool Scalable = N->getValueType(0).isScalableVT();
22583- // There's no nxv2i64 version of usdot
22584- if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
22585- return SDValue();
22586-
22587- Opcode = AArch64ISD::USDOT;
22588- // USDOT expects the signed operand to be last
22589- if (!MulOpRHSIsSigned)
22590- std::swap(MulOpLHS, MulOpRHS);
22591- } else
22592- Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22593-
22594- // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
22595- // product followed by a zero / sign extension
22596- if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22597- (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22598- EVT ReducedVTI32 =
22599- (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22600-
22601- SDValue DotI32 =
22602- DAG.getNode(Opcode, DL, ReducedVTI32,
22603- DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22604- SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22605- return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
22606- }
22607-
22608- return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
22609- }
22610-
22611- SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
22612- const AArch64Subtarget *Subtarget,
22613- SelectionDAG &DAG) {
22614-
22615- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
22616- getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
22617- "Expected a partial reduction node");
22618-
22619- if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22620- return SDValue();
22621-
22622- SDLoc DL(N);
22623-
22624- if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
22625- return SDValue();
22626- SDValue Acc = N->getOperand(1);
22627- SDValue Ext = N->getOperand(2);
22628- EVT AccVT = Acc.getValueType();
22629- EVT ExtVT = Ext.getValueType();
22630- if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
22631- return SDValue();
22632-
22633- SDValue ExtOp = Ext->getOperand(0);
22634- EVT ExtOpVT = ExtOp.getValueType();
22635-
22636- if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22637- !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22638- !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22639- return SDValue();
22640-
22641- bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
22642- unsigned BottomOpcode =
22643- ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22644- unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22645- SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
22646- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
22647- }
22648-
2264922517static SDValue combineSVEBitSel(unsigned IID, SDNode *N, SelectionDAG &DAG) {
2265022518 SDLoc DL(N);
2265122519 EVT VT = N->getValueType(0);
@@ -22678,17 +22546,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
2267822546 switch (IID) {
2267922547 default:
2268022548 break;
22681- case Intrinsic::vector_partial_reduce_add: {
22682- if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22683- return Dot;
22684- if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22685- return WideAdd;
22686- SDLoc DL(N);
22687- SDValue Input = N->getOperand(2);
22688- return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, N->getValueType(0),
22689- N->getOperand(1), Input,
22690- DAG.getConstant(1, DL, Input.getValueType()));
22691- }
2269222549 case Intrinsic::aarch64_neon_vcvtfxs2fp:
2269322550 case Intrinsic::aarch64_neon_vcvtfxu2fp:
2269422551 return tryCombineFixedPointConvert(N, DCI, DAG);
0 commit comments