@@ -4050,16 +4050,19 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40504050 // shl i64 X, Y -> [0, shl i32 X, (Y & 0x1F)]
40514051 if (VT == MVT::i64 ) {
40524052 KnownBits Known = DAG.computeKnownBits (RHS);
4053- if (Known.getMinValue ().getZExtValue () >= 32 ) {
4054- SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, MVT::i32 , RHS);
4055- const SDValue C31 = DAG.getConstant (31 , SL, MVT::i32 );
4053+ EVT TargetType=VT.getHalfSizedIntegerVT (*DAG.getContext ());
4054+ EVT TargetVecPairType=EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4055+
4056+ if (Known.getMinValue ().getZExtValue () >= TargetType.getSizeInBits ()) {
4057+ SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4058+ const SDValue ShiftMask = DAG.getConstant (TargetType.getSizeInBits () - 1 , SL, TargetType);
40564059 SDValue MaskedShiftAmt =
4057- DAG.getNode (ISD::AND, SL, MVT:: i32 , truncShiftAmt, C31 );
4058- SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, MVT:: i32 , LHS);
4059- SDValue NewShift = DAG.getNode (ISD::SHL, SL, MVT:: i32 , Lo, MaskedShiftAmt);
4060- const SDValue Zero = DAG.getConstant (0 , SL, MVT:: i32 );
4061- SDValue Vec = DAG.getBuildVector (MVT::v2i32 , SL, {Zero, NewShift});
4062- return DAG.getNode (ISD::BITCAST, SL, MVT:: i64 , Vec);
4060+ DAG.getNode (ISD::AND, SL, TargetType , truncShiftAmt, ShiftMask );
4061+ SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, TargetType , LHS);
4062+ SDValue NewShift = DAG.getNode (ISD::SHL, SL, TargetType , Lo, MaskedShiftAmt);
4063+ const SDValue Zero = DAG.getConstant (0 , SL, TargetType );
4064+ SDValue Vec = DAG.getBuildVector (TargetVecPairType , SL, {Zero, NewShift});
4065+ return DAG.getNode (ISD::BITCAST, SL, VT , Vec);
40634066 }
40644067 }
40654068 return SDValue ();
0 commit comments