@@ -4151,32 +4151,96 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
41514151
41524152SDValue AMDGPUTargetLowering::performSraCombine (SDNode *N,
41534153 DAGCombinerInfo &DCI) const {
4154- if (N->getValueType (0 ) != MVT::i64 )
4154+ SDValue RHS = N->getOperand (1 );
4155+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4156+ EVT VT = N->getValueType (0 );
4157+ SDValue LHS = N->getOperand (0 );
4158+ SelectionDAG &DAG = DCI.DAG ;
4159+ SDLoc SL (N);
4160+
4161+ if (VT.getScalarType () != MVT::i64 )
41554162 return SDValue ();
41564163
4157- const ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
4158- if (!RHS)
4164+ // For C >= 32
4165+ // i64 (sra x, C) -> (build_pair (sra hi_32(x), C - 32), sra hi_32(x), 31))
4166+
4167+ // On some subtargets, 64-bit shift is a quarter rate instruction. In the
4168+ // common case, splitting this into a move and a 32-bit shift is faster and
4169+ // the same code size.
4170+ KnownBits Known = DAG.computeKnownBits (RHS);
4171+
4172+ EVT ElementType = VT.getScalarType ();
4173+ EVT TargetScalarType = ElementType.getHalfSizedIntegerVT (*DAG.getContext ());
4174+ EVT TargetType = VT.isVector () ? VT.changeVectorElementType (TargetScalarType)
4175+ : TargetScalarType;
4176+
4177+ if (Known.getMinValue ().getZExtValue () < TargetScalarType.getSizeInBits ())
41594178 return SDValue ();
41604179
4161- SelectionDAG &DAG = DCI.DAG ;
4162- SDLoc SL (N);
4163- unsigned RHSVal = RHS->getZExtValue ();
4180+ SDValue ShiftFullAmt =
4181+ DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4182+ SDValue ShiftAmt;
4183+ if (CRHS) {
4184+ unsigned RHSVal = CRHS->getZExtValue ();
4185+ ShiftAmt = DAG.getConstant (RHSVal - TargetScalarType.getSizeInBits (), SL,
4186+ TargetType);
4187+ } else if (Known.getMinValue ().getZExtValue () ==
4188+ (ElementType.getSizeInBits () - 1 )) {
4189+ ShiftAmt = ShiftFullAmt;
4190+ } else {
4191+ SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4192+ const SDValue ShiftMask =
4193+ DAG.getConstant (TargetScalarType.getSizeInBits () - 1 , SL, TargetType);
4194+ // This AND instruction will clamp out of bounds shift values.
4195+ // It will also be removed during later instruction selection.
4196+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4197+ }
41644198
4165- // For C >= 32
4166- // (sra i64:x, C) -> build_pair (sra hi_32(x), C - 32), (sra hi_32(x), 31)
4167- if (RHSVal >= 32 ) {
4168- SDValue Hi = getHiHalf64 (N->getOperand (0 ), DAG);
4169- Hi = DAG.getFreeze (Hi);
4170- SDValue HiShift = DAG.getNode (ISD::SRA, SL, MVT::i32 , Hi,
4171- DAG.getConstant (31 , SL, MVT::i32 ));
4172- SDValue LoShift = DAG.getNode (ISD::SRA, SL, MVT::i32 , Hi,
4173- DAG.getConstant (RHSVal - 32 , SL, MVT::i32 ));
4199+ EVT ConcatType;
4200+ SDValue Hi;
4201+ SDLoc LHSSL (LHS);
4202+ // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
4203+ if (VT.isVector ()) {
4204+ unsigned NElts = TargetType.getVectorNumElements ();
4205+ ConcatType = TargetType.getDoubleNumVectorElementsVT (*DAG.getContext ());
4206+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4207+ SmallVector<SDValue, 8 > HiOps (NElts);
4208+ SmallVector<SDValue, 16 > HiAndLoOps;
41744209
4175- SDValue BuildVec = DAG.getBuildVector (MVT::v2i32, SL, {LoShift, HiShift});
4176- return DAG.getNode (ISD::BITCAST, SL, MVT::i64 , BuildVec);
4210+ DAG.ExtractVectorElements (SplitLHS, HiAndLoOps, 0 , NElts * 2 );
4211+ for (unsigned I = 0 ; I != NElts; ++I) {
4212+ HiOps[I] = HiAndLoOps[2 * I + 1 ];
4213+ }
4214+ Hi = DAG.getNode (ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
4215+ } else {
4216+ const SDValue One = DAG.getConstant (1 , LHSSL, TargetScalarType);
4217+ ConcatType = EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4218+ SDValue SplitLHS = DAG.getNode (ISD::BITCAST, LHSSL, ConcatType, LHS);
4219+ Hi = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
41774220 }
4221+ Hi = DAG.getFreeze (Hi);
41784222
4179- return SDValue ();
4223+ SDValue HiShift = DAG.getNode (ISD::SRA, SL, TargetType, Hi, ShiftFullAmt);
4224+ SDValue NewShift = DAG.getNode (ISD::SRA, SL, TargetType, Hi, ShiftAmt);
4225+
4226+ SDValue Vec;
4227+ if (VT.isVector ()) {
4228+ unsigned NElts = TargetType.getVectorNumElements ();
4229+ SmallVector<SDValue, 8 > HiOps;
4230+ SmallVector<SDValue, 8 > LoOps;
4231+ SmallVector<SDValue, 16 > HiAndLoOps (NElts * 2 );
4232+
4233+ DAG.ExtractVectorElements (HiShift, HiOps, 0 , NElts);
4234+ DAG.ExtractVectorElements (NewShift, LoOps, 0 , NElts);
4235+ for (unsigned I = 0 ; I != NElts; ++I) {
4236+ HiAndLoOps[2 * I + 1 ] = HiOps[I];
4237+ HiAndLoOps[2 * I] = LoOps[I];
4238+ }
4239+ Vec = DAG.getNode (ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
4240+ } else {
4241+ Vec = DAG.getBuildVector (ConcatType, SL, {NewShift, HiShift});
4242+ }
4243+ return DAG.getNode (ISD::BITCAST, SL, VT, Vec);
41804244}
41814245
41824246SDValue AMDGPUTargetLowering::performSrlCombine (SDNode *N,
@@ -4213,7 +4277,7 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
42134277 return SDValue ();
42144278
42154279 // for C >= 32
4216- // i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
4280+ // i64 (srl x, C) -> (build_pair (srl hi_32(x), C - 32), 0)
42174281
42184282 // On some subtargets, 64-bit shift is a quarter rate instruction. In the
42194283 // common case, splitting this into a move and a 32-bit shift is faster and
@@ -5265,25 +5329,22 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
52655329 break ;
52665330 }
52675331 case ISD::SHL:
5332+ case ISD::SRA:
52685333 case ISD::SRL: {
52695334 // Range metadata can be invalidated when loads are converted to legal types
52705335 // (e.g. v2i64 -> v4i32).
5271- // Try to convert vector shl/srl before type legalization so that range
5336+ // Try to convert vector shl/sra/ srl before type legalization so that range
52725337 // metadata can be utilized.
52735338 if (!(N->getValueType (0 ).isVector () &&
52745339 DCI.getDAGCombineLevel () == BeforeLegalizeTypes) &&
52755340 DCI.getDAGCombineLevel () < AfterLegalizeDAG)
52765341 break ;
52775342 if (N->getOpcode () == ISD::SHL)
52785343 return performShlCombine (N, DCI);
5344+ if (N->getOpcode () == ISD::SRA)
5345+ return performSraCombine (N, DCI);
52795346 return performSrlCombine (N, DCI);
52805347 }
5281- case ISD::SRA: {
5282- if (DCI.getDAGCombineLevel () < AfterLegalizeDAG)
5283- break ;
5284-
5285- return performSraCombine (N, DCI);
5286- }
52875348 case ISD::TRUNCATE:
52885349 return performTruncateCombine (N, DCI);
52895350 case ISD::MUL:
0 commit comments