@@ -4040,47 +4040,48 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
40404040SDValue AMDGPUTargetLowering::performShlCombine (SDNode *N,
40414041 DAGCombinerInfo &DCI) const {
40424042 EVT VT = N->getValueType (0 );
4043-
4044- ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
4045- if (!RHS)
4046- return SDValue ();
4047-
40484043 SDValue LHS = N->getOperand (0 );
4049- unsigned RHSVal = RHS->getZExtValue ();
4050- if (!RHSVal)
4051- return LHS;
4052-
4044+ SDValue RHS = N->getOperand (1 );
4045+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
40534046 SDLoc SL (N);
40544047 SelectionDAG &DAG = DCI.DAG ;
40554048
4056- switch (LHS->getOpcode ()) {
4057- default :
4058- break ;
4059- case ISD::ZERO_EXTEND:
4060- case ISD::SIGN_EXTEND:
4061- case ISD::ANY_EXTEND: {
4062- SDValue X = LHS->getOperand (0 );
4063-
4064- if (VT == MVT::i32 && RHSVal == 16 && X.getValueType () == MVT::i16 &&
4065- isOperationLegal (ISD::BUILD_VECTOR, MVT::v2i16)) {
4066- // Prefer build_vector as the canonical form if packed types are legal.
4067- // (shl ([asz]ext i16:x), 16 -> build_vector 0, x
4068- SDValue Vec = DAG.getBuildVector (MVT::v2i16, SL,
4069- { DAG.getConstant (0 , SL, MVT::i16 ), LHS->getOperand (0 ) });
4070- return DAG.getNode (ISD::BITCAST, SL, MVT::i32 , Vec);
4071- }
4049+ unsigned RHSVal;
4050+ if (CRHS) {
4051+ RHSVal = CRHS->getZExtValue ();
4052+ if (!RHSVal)
4053+ return LHS;
40724054
4073- // shl (ext x) => zext (shl x), if shift does not overflow int
4074- if (VT != MVT::i64 )
4075- break ;
4076- KnownBits Known = DAG.computeKnownBits (X);
4077- unsigned LZ = Known.countMinLeadingZeros ();
4078- if (LZ < RHSVal)
4055+ switch (LHS->getOpcode ()) {
4056+ default :
40794057 break ;
4080- EVT XVT = X.getValueType ();
4081- SDValue Shl = DAG.getNode (ISD::SHL, SL, XVT, X, SDValue (RHS, 0 ));
4082- return DAG.getZExtOrTrunc (Shl, SL, VT);
4083- }
4058+ case ISD::ZERO_EXTEND:
4059+ case ISD::SIGN_EXTEND:
4060+ case ISD::ANY_EXTEND: {
4061+ SDValue X = LHS->getOperand (0 );
4062+
4063+ if (VT == MVT::i32 && RHSVal == 16 && X.getValueType () == MVT::i16 &&
4064+ isOperationLegal (ISD::BUILD_VECTOR, MVT::v2i16)) {
4065+ // Prefer build_vector as the canonical form if packed types are legal.
4066+ // (shl ([asz]ext i16:x), 16 -> build_vector 0, x
4067+ SDValue Vec = DAG.getBuildVector (
4068+ MVT::v2i16, SL,
4069+ {DAG.getConstant (0 , SL, MVT::i16 ), LHS->getOperand (0 )});
4070+ return DAG.getNode (ISD::BITCAST, SL, MVT::i32 , Vec);
4071+ }
4072+
4073+ // shl (ext x) => zext (shl x), if shift does not overflow int
4074+ if (VT != MVT::i64 )
4075+ break ;
4076+ KnownBits Known = DAG.computeKnownBits (X);
4077+ unsigned LZ = Known.countMinLeadingZeros ();
4078+ if (LZ < RHSVal)
4079+ break ;
4080+ EVT XVT = X.getValueType ();
4081+ SDValue Shl = DAG.getNode (ISD::SHL, SL, XVT, X, SDValue (CRHS, 0 ));
4082+ return DAG.getZExtOrTrunc (Shl, SL, VT);
4083+ }
4084+ }
40844085 }
40854086
40864087 if (VT != MVT::i64 )
@@ -4091,18 +4092,34 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40914092 // On some subtargets, 64-bit shift is a quarter rate instruction. In the
40924093 // common case, splitting this into a move and a 32-bit shift is faster and
40934094 // the same code size.
4094- if (RHSVal < 32 )
4095+ EVT TargetType = VT.getHalfSizedIntegerVT (*DAG.getContext ());
4096+ EVT TargetVecPairType = EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4097+ KnownBits Known = DAG.computeKnownBits (RHS);
4098+
4099+ if (Known.getMinValue ().getZExtValue () < TargetType.getSizeInBits ())
40954100 return SDValue ();
4101+ SDValue ShiftAmt;
40964102
4097- SDValue ShiftAmt = DAG.getConstant (RHSVal - 32 , SL, MVT::i32 );
4103+ if (CRHS) {
4104+ ShiftAmt =
4105+ DAG.getConstant (RHSVal - TargetType.getSizeInBits (), SL, TargetType);
4106+ } else {
4107+ SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4108+ const SDValue ShiftMask =
4109+ DAG.getConstant (TargetType.getSizeInBits () - 1 , SL, TargetType);
4110+ // This AND instruction will clamp out of bounds shift values.
4111+ // It will also be removed during later instruction selection.
4112+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4113+ }
40984114
4099- SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, MVT::i32 , LHS);
4100- SDValue NewShift = DAG.getNode (ISD::SHL, SL, MVT::i32 , Lo, ShiftAmt);
4115+ SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, TargetType, LHS);
4116+ SDValue NewShift =
4117+ DAG.getNode (ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags ());
41014118
4102- const SDValue Zero = DAG.getConstant (0 , SL, MVT:: i32 );
4119+ const SDValue Zero = DAG.getConstant (0 , SL, TargetType );
41034120
4104- SDValue Vec = DAG.getBuildVector (MVT::v2i32 , SL, {Zero, NewShift});
4105- return DAG.getNode (ISD::BITCAST, SL, MVT:: i64 , Vec);
4121+ SDValue Vec = DAG.getBuildVector (TargetVecPairType , SL, {Zero, NewShift});
4122+ return DAG.getNode (ISD::BITCAST, SL, VT , Vec);
41064123}
41074124
41084125SDValue AMDGPUTargetLowering::performSraCombine (SDNode *N,
0 commit comments