@@ -54214,22 +54214,31 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
5421454214// cases.
5421554215static SDValue combinei64TruncSrlConstant(SDValue N, EVT VT, SelectionDAG &DAG,
5421654216 const SDLoc &DL) {
54217+ assert(N.getOpcode() == ISD::SRL && "Unknown shift opcode");
54218+ std::optional<uint64_t> ValidSrlConst = DAG.getValidShiftAmount(N);
54219+ if (!ValidSrlConst)
54220+ return SDValue();
54221+ uint64_t SrlConstVal = *ValidSrlConst;
5421754222
5421854223 SDValue Op = N.getOperand(0);
54219- APInt OpConst = Op.getConstantOperandAPInt(1);
54220- APInt SrlConst = N.getConstantOperandAPInt(1);
54221- uint64_t SrlConstVal = SrlConst.getZExtValue();
5422254224 unsigned Opcode = Op.getOpcode();
54225+ assert(VT == MVT::i32 && Op.getValueType() == MVT::i64 &&
54226+ "Illegal truncation types");
54227+
54228+ if ((Opcode != ISD::ADD && Opcode != ISD::OR && Opcode != ISD::XOR) ||
54229+ !isa<ConstantSDNode>(Op.getOperand(1)))
54230+ return SDValue();
54231+ const APInt &OpConst = Op.getConstantOperandAPInt(1);
5422354232
54224- if (SrlConst.ule(32) ||
54233+ if (SrlConstVal <= 32 ||
5422554234 (Opcode == ISD::ADD && OpConst.countr_zero() < SrlConstVal))
5422654235 return SDValue();
5422754236
5422854237 SDValue OpLhsSrl =
5422954238 DAG.getNode(ISD::SRL, DL, MVT::i64, Op.getOperand(0), N.getOperand(1));
5423054239 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, OpLhsSrl);
5423154240
54232- APInt NewOpConstVal = OpConst.lshr(SrlConst ).trunc(VT.getSizeInBits());
54241+ APInt NewOpConstVal = OpConst.lshr(SrlConstVal ).trunc(VT.getSizeInBits());
5423354242 SDValue NewOpConst = DAG.getConstant(NewOpConstVal, DL, VT);
5423454243 SDValue NewOpNode = DAG.getNode(Opcode, DL, VT, Trunc, NewOpConst);
5423554244
@@ -54285,20 +54294,8 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
5428554294 if (!Src.hasOneUse())
5428654295 return SDValue();
5428754296
54288- if (VT == MVT::i32 && SrcVT == MVT::i64 && SrcOpcode == ISD::SRL &&
54289- isa<ConstantSDNode>(Src.getOperand(1))) {
54290-
54291- unsigned SrcOpOpcode = Src.getOperand(0).getOpcode();
54292- if ((SrcOpOpcode != ISD::ADD && SrcOpOpcode != ISD::OR &&
54293- SrcOpOpcode != ISD::XOR) ||
54294- !isa<ConstantSDNode>(Src.getOperand(0).getOperand(1)))
54295- return SDValue();
54296-
54297- if (SDValue R = combinei64TruncSrlConstant(Src, VT, DAG, DL))
54298- return R;
54299-
54300- return SDValue();
54301- }
54297+ if (VT == MVT::i32 && SrcVT == MVT::i64 && SrcOpcode == ISD::SRL)
54298+ return combinei64TruncSrlConstant(Src, VT, DAG, DL);
5430254299
5430354300 if (!VT.isVector())
5430454301 return SDValue();
0 commit comments