@@ -4097,7 +4097,7 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40974097 if (VT.getScalarType () != MVT::i64 )
40984098 return SDValue ();
40994099
4100- // i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
4100+ // i64 (shl x, C) -> (build_pair 0, (shl x, C - 32))
41014101
41024102 // On some subtargets, 64-bit shift is a quarter rate instruction. In the
41034103 // common case, splitting this into a move and a 32-bit shift is faster and
@@ -4117,12 +4117,12 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
41174117 ShiftAmt = DAG.getConstant (RHSVal - TargetScalarType.getSizeInBits (), SL,
41184118 TargetType);
41194119 } else {
4120- SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4120+ SDValue TruncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
41214121 const SDValue ShiftMask =
41224122 DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
41234123 // This AND instruction will clamp out of bounds shift values.
41244124 // It will also be removed during later instruction selection.
4125- ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, truncShiftAmt , ShiftMask);
4125+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, TruncShiftAmt , ShiftMask);
41264126 }
41274127
41284128 SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, TargetType, LHS);
@@ -4181,50 +4181,105 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
41814181
41824182SDValue AMDGPUTargetLowering::performSrlCombine (SDNode *N,
41834183 DAGCombinerInfo &DCI) const {
4184- auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
4185- if (!RHS)
4186- return SDValue ();
4187-
4184+ SDValue RHS = N->getOperand (1 );
4185+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
41884186 EVT VT = N->getValueType (0 );
41894187 SDValue LHS = N->getOperand (0 );
4190- unsigned ShiftAmt = RHS->getZExtValue ();
41914188 SelectionDAG &DAG = DCI.DAG ;
41924189 SDLoc SL (N);
4190+ unsigned RHSVal;
4191+
4192+ if (CRHS) {
4193+ RHSVal = CRHS->getZExtValue ();
41934194
4194- // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
4195- // this improves the ability to match BFE patterns in isel.
4196- if (LHS.getOpcode () == ISD::AND) {
4197- if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand (1 ))) {
4198- unsigned MaskIdx, MaskLen;
4199- if (Mask->getAPIntValue ().isShiftedMask (MaskIdx, MaskLen) &&
4200- MaskIdx == ShiftAmt) {
4201- return DAG.getNode (
4202- ISD::AND, SL, VT,
4203- DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (0 ), N->getOperand (1 )),
4204- DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (1 ), N->getOperand (1 )));
4195+ // fold (srl (and x, c1 << c2), c2) -> (and (srl x, c2), c1)
4196+ // This improves the ability to match BFE patterns in isel.
4197+ if (LHS.getOpcode () == ISD::AND) {
4198+ if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand (1 ))) {
4199+ unsigned MaskIdx, MaskLen;
4200+ if (Mask->getAPIntValue ().isShiftedMask (MaskIdx, MaskLen) &&
4201+ MaskIdx == RHSVal) {
4202+ return DAG.getNode (ISD::AND, SL, VT,
4203+ DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (0 ),
4204+ N->getOperand (1 )),
4205+ DAG.getNode (ISD::SRL, SL, VT, LHS.getOperand (1 ),
4206+ N->getOperand (1 )));
4207+ }
42054208 }
42064209 }
42074210 }
42084211
4209- if (VT != MVT::i64 )
4212+ if (VT. getScalarType () != MVT::i64 )
42104213 return SDValue ();
42114214
4212- if (ShiftAmt < 32 )
4215+ // For C >= 32
4216+ // i64 (srl x, C) -> (build_pair (srl hi_32(x), C - 32), 0)
4217+
4218+ // On some subtargets, 64-bit shift is a quarter rate instruction. In the
4219+ // common case, splitting this into a move and a 32-bit shift is faster and
4220+ // the same code size.
4221+ KnownBits Known = DAG.computeKnownBits (RHS);
4222+
4223+ EVT ElementType = VT.getScalarType ();
4224+ EVT TargetScalarType = ElementType.getHalfSizedIntegerVT (*DAG.getContext ());
4225+ EVT TargetType = VT.isVector () ? VT.changeVectorElementType (TargetScalarType)
4226+ : TargetScalarType;
4227+
4228+ if (Known.getMinValue ().getZExtValue () < TargetScalarType.getSizeInBits ())
42134229 return SDValue ();
42144230
4215- // srl i64:x, C for C >= 32
4216- // =>
4217- // build_pair (srl hi_32(x), C - 32), 0
4218- SDValue Zero = DAG.getConstant (0 , SL, MVT::i32 );
4231+ SDValue ShiftAmt;
4232+ if (CRHS) {
4233+ ShiftAmt = DAG.getConstant (RHSVal - TargetScalarType.getSizeInBits (), SL,
4234+ TargetType);
4235+ } else {
4236+ SDValue TruncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4237+ const SDValue ShiftMask =
4238+ DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4239+ // This AND instruction will clamp out of bounds shift values.
4240+ // It will also be removed during later instruction selection.
4241+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, TruncShiftAmt, ShiftMask);
4242+ }
4243+
4244+ const SDValue Zero = DAG.getConstant (0 , SL, TargetScalarType);
4245+ EVT ConcatType;
4246+ SDValue Hi;
4247+ SDLoc LHSSL (LHS);
4248+ // Bitcast LHS into ConcatType so the hi-half of the source can be extracted into Hi.
4249+ if (VT.isVector ()) {
4250+ unsigned NElts = TargetType.getVectorNumElements ();
4251+ ConcatType = TargetType.getDoubleNumVectorElementsVT (*DAG.getContext ());
4252+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4253+ SmallVector<SDValue, 8 > HiOps (NElts);
4254+ SmallVector<SDValue, 16 > HiAndLoOps;
42194255
4220- SDValue Hi = getHiHalf64 (LHS, DAG);
4256+ DAG.ExtractVectorElements (SplitLHS, HiAndLoOps, /* Start=*/ 0 , NElts * 2 );
4257+ for (unsigned I = 0 ; I != NElts; ++I)
4258+ HiOps[I] = HiAndLoOps[2 * I + 1 ];
4259+ Hi = DAG.getNode (ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
4260+ } else {
4261+ const SDValue One = DAG.getConstant (1 , LHSSL, TargetScalarType);
4262+ ConcatType = EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4263+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4264+ Hi = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
4265+ }
42214266
4222- SDValue NewConst = DAG.getConstant (ShiftAmt - 32 , SL, MVT::i32 );
4223- SDValue NewShift = DAG.getNode (ISD::SRL, SL, MVT::i32 , Hi, NewConst);
4267+ SDValue NewShift = DAG.getNode (ISD::SRL, SL, TargetType, Hi, ShiftAmt);
42244268
4225- SDValue BuildPair = DAG.getBuildVector (MVT::v2i32, SL, {NewShift, Zero});
4269+ SDValue Vec;
4270+ if (VT.isVector ()) {
4271+ unsigned NElts = TargetType.getVectorNumElements ();
4272+ SmallVector<SDValue, 8 > LoOps;
4273+ SmallVector<SDValue, 16 > HiAndLoOps (NElts * 2 , Zero);
42264274
4227- return DAG.getNode (ISD::BITCAST, SL, MVT::i64 , BuildPair);
4275+ DAG.ExtractVectorElements (NewShift, LoOps, 0 , NElts);
4276+ for (unsigned I = 0 ; I != NElts; ++I)
4277+ HiAndLoOps[2 * I] = LoOps[I];
4278+ Vec = DAG.getNode (ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
4279+ } else {
4280+ Vec = DAG.getBuildVector (ConcatType, SL, {NewShift, Zero});
4281+ }
4282+ return DAG.getNode (ISD::BITCAST, SL, VT, Vec);
42284283}
42294284
42304285SDValue AMDGPUTargetLowering::performTruncateCombine (
@@ -5209,21 +5264,18 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
52095264
52105265 break ;
52115266 }
5212- case ISD::SHL: {
5267+ case ISD::SHL:
5268+ case ISD::SRL: {
52135269 // Range metadata can be invalidated when loads are converted to legal types
52145270 // (e.g. v2i64 -> v4i32).
5215- // Try to convert vector shl before type legalization so that range metadata
5216- // can be utilized.
5271+ // Try to convert vector shl/srl before type legalization so that range
5272+ // metadata can be utilized.
52175273 if (!(N->getValueType (0 ).isVector () &&
52185274 DCI.getDAGCombineLevel () == BeforeLegalizeTypes) &&
52195275 DCI.getDAGCombineLevel () < AfterLegalizeDAG)
52205276 break ;
5221- return performShlCombine (N, DCI);
5222- }
5223- case ISD::SRL: {
5224- if (DCI.getDAGCombineLevel () < AfterLegalizeDAG)
5225- break ;
5226-
5277+ if (N->getOpcode () == ISD::SHL)
5278+ return performShlCombine (N, DCI);
52275279 return performSrlCombine (N, DCI);
52285280 }
52295281 case ISD::SRA: {
0 commit comments