@@ -2690,6 +2690,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
26902690 MAKE_CASE(AArch64ISD::RSHRNB_I)
26912691 MAKE_CASE(AArch64ISD::CTTZ_ELTS)
26922692 MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
2693+ MAKE_CASE(AArch64ISD::URSHR_I_PRED)
26932694 }
26942695#undef MAKE_CASE
26952696 return nullptr;
@@ -2974,6 +2975,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
29742975static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
29752976static SDValue convertFixedMaskToScalableVector(SDValue Mask,
29762977 SelectionDAG &DAG);
2978+ static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
29772979static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
29782980 EVT VT);
29792981
@@ -13862,6 +13864,51 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
1386213864 return SDValue();
1386313865}
1386413866
13867+ // Check if we can we lower this SRL to a rounding shift instruction. ResVT is
13868+ // possibly a truncated type, it tells how many bits of the value are to be
13869+ // used.
13870+ static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
13871+ SelectionDAG &DAG,
13872+ unsigned &ShiftValue,
13873+ SDValue &RShOperand) {
13874+ if (Shift->getOpcode() != ISD::SRL)
13875+ return false;
13876+
13877+ EVT VT = Shift.getValueType();
13878+ assert(VT.isScalableVT());
13879+
13880+ auto ShiftOp1 =
13881+ dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
13882+ if (!ShiftOp1)
13883+ return false;
13884+
13885+ ShiftValue = ShiftOp1->getZExtValue();
13886+ if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
13887+ return false;
13888+
13889+ SDValue Add = Shift->getOperand(0);
13890+ if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
13891+ return false;
13892+
13893+ assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
13894+ "ResVT must be truncated or same type as the shift.");
13895+ // Check if an overflow can lead to incorrect results.
13896+ uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
13897+ if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
13898+ return false;
13899+
13900+ auto AddOp1 =
13901+ dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
13902+ if (!AddOp1)
13903+ return false;
13904+ uint64_t AddValue = AddOp1->getZExtValue();
13905+ if (AddValue != 1ULL << (ShiftValue - 1))
13906+ return false;
13907+
13908+ RShOperand = Add->getOperand(0);
13909+ return true;
13910+ }
13911+
1386513912SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1386613913 SelectionDAG &DAG) const {
1386713914 EVT VT = Op.getValueType();
@@ -13887,6 +13934,15 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1388713934 Op.getOperand(0), Op.getOperand(1));
1388813935 case ISD::SRA:
1388913936 case ISD::SRL:
13937+ if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
13938+ SDValue RShOperand;
13939+ unsigned ShiftValue;
13940+ if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
13941+ return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
13942+ getPredicateForVector(DAG, DL, VT), RShOperand,
13943+ DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
13944+ }
13945+
1389013946 if (VT.isScalableVector() ||
1389113947 useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
1389213948 unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -17711,9 +17767,6 @@ static SDValue performReinterpretCastCombine(SDNode *N) {
1771117767
1771217768static SDValue performSVEAndCombine(SDNode *N,
1771317769 TargetLowering::DAGCombinerInfo &DCI) {
17714- if (DCI.isBeforeLegalizeOps())
17715- return SDValue();
17716-
1771717770 SelectionDAG &DAG = DCI.DAG;
1771817771 SDValue Src = N->getOperand(0);
1771917772 unsigned Opc = Src->getOpcode();
@@ -17769,6 +17822,9 @@ static SDValue performSVEAndCombine(SDNode *N,
1776917822 return DAG.getNode(Opc, DL, N->getValueType(0), And);
1777017823 }
1777117824
17825+ if (DCI.isBeforeLegalizeOps())
17826+ return SDValue();
17827+
1777217828 // If both sides of AND operations are i1 splat_vectors then
1777317829 // we can produce just i1 splat_vector as the result.
1777417830 if (isAllActivePredicate(DAG, N->getOperand(0)))
@@ -20216,6 +20272,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
2021620272 case Intrinsic::aarch64_sve_uqsub_x:
2021720273 return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
2021820274 N->getOperand(1), N->getOperand(2));
20275+ case Intrinsic::aarch64_sve_urshr:
20276+ return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
20277+ N->getOperand(1), N->getOperand(2), N->getOperand(3));
2021920278 case Intrinsic::aarch64_sve_asrd:
2022020279 return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
2022120280 N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20832,6 +20891,51 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
2083220891 return SDValue();
2083320892}
2083420893
20894+ static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
20895+ if (N->getOpcode() != AArch64ISD::UZP1)
20896+ return false;
20897+ SDValue Op0 = N->getOperand(0);
20898+ EVT SrcVT = Op0->getValueType(0);
20899+ EVT DstVT = N->getValueType(0);
20900+ return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv16i8) ||
20901+ (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv8i16) ||
20902+ (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv4i32);
20903+ }
20904+
20905+ // Try to combine rounding shifts where the operands come from an extend, and
20906+ // the result is truncated and combined into one vector.
20907+ // uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
20908+ static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
20909+ assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
20910+ SDValue Op0 = N->getOperand(0);
20911+ SDValue Op1 = N->getOperand(1);
20912+ EVT ResVT = N->getValueType(0);
20913+
20914+ unsigned RshOpc = Op0.getOpcode();
20915+ if (RshOpc != AArch64ISD::RSHRNB_I)
20916+ return SDValue();
20917+
20918+ // Same op code and imm value?
20919+ SDValue ShiftValue = Op0.getOperand(1);
20920+ if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
20921+ return SDValue();
20922+
20923+ // Same unextended operand value?
20924+ SDValue Lo = Op0.getOperand(0);
20925+ SDValue Hi = Op1.getOperand(0);
20926+ if (Lo.getOpcode() != AArch64ISD::UUNPKLO &&
20927+ Hi.getOpcode() != AArch64ISD::UUNPKHI)
20928+ return SDValue();
20929+ SDValue OrigArg = Lo.getOperand(0);
20930+ if (OrigArg != Hi.getOperand(0))
20931+ return SDValue();
20932+
20933+ SDLoc DL(N);
20934+ return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
20935+ getPredicateForVector(DAG, DL, ResVT), OrigArg,
20936+ ShiftValue);
20937+ }
20938+
2083520939// Try to simplify:
2083620940// t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
2083720941// t2 = nxv8i16 srl(t1, ShiftValue)
@@ -20844,9 +20948,7 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
2084420948static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
2084520949 const AArch64Subtarget *Subtarget) {
2084620950 EVT VT = Srl->getValueType(0);
20847-
20848- if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
20849- Srl->getOpcode() != ISD::SRL)
20951+ if (!VT.isScalableVector() || !Subtarget->hasSVE2())
2085020952 return SDValue();
2085120953
2085220954 EVT ResVT;
@@ -20859,29 +20961,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
2085920961 else
2086020962 return SDValue();
2086120963
20862- auto SrlOp1 =
20863- dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
20864- if (!SrlOp1)
20865- return SDValue();
20866- unsigned ShiftValue = SrlOp1->getZExtValue();
20867- if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
20868- return SDValue();
20869-
20870- SDValue Add = Srl->getOperand(0);
20871- if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
20872- return SDValue();
20873- auto AddOp1 =
20874- dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
20875- if (!AddOp1)
20876- return SDValue();
20877- uint64_t AddValue = AddOp1->getZExtValue();
20878- if (AddValue != 1ULL << (ShiftValue - 1))
20879- return SDValue();
20880-
2088120964 SDLoc DL(Srl);
20965+ unsigned ShiftValue;
20966+ SDValue RShOperand;
20967+ if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
20968+ return SDValue();
2088220969 SDValue Rshrnb = DAG.getNode(
2088320970 AArch64ISD::RSHRNB_I, DL, ResVT,
20884- {Add->getOperand(0) , DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
20971+ {RShOperand , DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
2088520972 return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
2088620973}
2088720974
@@ -20919,6 +21006,9 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
2091921006 }
2092021007 }
2092121008
21009+ if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
21010+ return Urshr;
21011+
2092221012 if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
2092321013 return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
2092421014
@@ -20949,6 +21039,19 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
2094921039 if (!IsLittleEndian)
2095021040 return SDValue();
2095121041
21042+ // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
21043+ // Example:
21044+ // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
21045+ // to
21046+ // nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
21047+ if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
21048+ Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
21049+ if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
21050+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
21051+ Op1.getOperand(0));
21052+ }
21053+ }
21054+
2095221055 if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
2095321056 return SDValue();
2095421057
0 commit comments