@@ -8735,30 +8735,26 @@ SDValue SystemZTargetLowering::combineSETCC(
87358735}
87368736
87378737static std::pair<SDValue, int > findCCUse (const SDValue &Val) {
8738- auto *N = Val.getNode ();
8739- if (!N)
8740- return std::make_pair (SDValue (), SystemZ::CCMASK_NONE);
8741- switch (N->getOpcode ()) {
8738+ switch (Val.getOpcode ()) {
87428739 default :
87438740 return std::make_pair (SDValue (), SystemZ::CCMASK_NONE);
87448741 case SystemZISD::IPM:
8745- if (N-> getOperand (0 ).getOpcode () == SystemZISD::CLC ||
8746- N-> getOperand (0 ).getOpcode () == SystemZISD::STRCMP)
8747- return std::make_pair (N-> getOperand (0 ), SystemZ::CCMASK_ICMP);
8748- return std::make_pair (N-> getOperand (0 ), SystemZ::CCMASK_ANY);
8742+ if (Val. getOperand (0 ).getOpcode () == SystemZISD::CLC ||
8743+ Val. getOperand (0 ).getOpcode () == SystemZISD::STRCMP)
8744+ return std::make_pair (Val. getOperand (0 ), SystemZ::CCMASK_ICMP);
8745+ return std::make_pair (Val. getOperand (0 ), SystemZ::CCMASK_ANY);
87498746 case SystemZISD::SELECT_CCMASK: {
8750- SDValue Op4CCReg = N->getOperand (4 );
8751- auto *Op4CCNode = Op4CCReg.getNode ();
8752- auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand (2 ));
8753- if (!CCValid || !Op4CCNode)
8754- return std::make_pair (SDValue (), SystemZ::CCMASK_NONE);
8755- int CCValidVal = CCValid->getZExtValue ();
8756- if (Op4CCNode->getOpcode () == SystemZISD::ICMP ||
8757- Op4CCNode->getOpcode () == SystemZISD::TM) {
8758- auto [OpCC, OpCCValid] = findCCUse (Op4CCNode->getOperand (0 ));
8747+ SDValue Op4CCReg = Val.getOperand (4 );
8748+ if (Op4CCReg.getOpcode () == SystemZISD::ICMP ||
8749+ Op4CCReg.getOpcode () == SystemZISD::TM) {
8750+ auto [OpCC, OpCCValid] = findCCUse (Op4CCReg.getOperand (0 ));
87598751 if (OpCC != SDValue ())
87608752 return std::make_pair (OpCC, OpCCValid);
87618753 }
8754+ auto *CCValid = dyn_cast<ConstantSDNode>(Val.getOperand (2 ));
8755+ if (!CCValid)
8756+ return std::make_pair (SDValue (), SystemZ::CCMASK_NONE);
8757+ int CCValidVal = CCValid->getZExtValue ();
87628758 return std::make_pair (Op4CCReg, CCValidVal);
87638759 }
87648760 case ISD::ADD:
@@ -8768,10 +8764,10 @@ static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
87688764 case ISD::SHL:
87698765 case ISD::SRA:
87708766 case ISD::SRL:
8771- auto [Op0CC, Op0CCValid] = findCCUse (N-> getOperand (0 ));
8767+ auto [Op0CC, Op0CCValid] = findCCUse (Val. getOperand (0 ));
87728768 if (Op0CC != SDValue ())
87738769 return std::make_pair (Op0CC, Op0CCValid);
8774- return findCCUse (N-> getOperand (1 ));
8770+ return findCCUse (Val. getOperand (1 ));
87758771 }
87768772}
87778773
@@ -8780,19 +8776,18 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
87808776
87818777SmallVector<SDValue, 4 > static simplifyAssumingCCVal (SDValue &Val, SDValue &CC,
87828778 SelectionDAG &DAG) {
8783- auto *N = Val.getNode (), *CCNode = CC.getNode ();
8784- if (!N || !CCNode)
8779+ if (CC == SDValue ())
87858780 return {};
8786- SDLoc DL (N );
8787- auto Opcode = N-> getOpcode ();
8781+ SDLoc DL (Val );
8782+ auto Opcode = Val. getOpcode ();
87888783 switch (Opcode) {
87898784 default :
87908785 return {};
87918786 case ISD::Constant:
87928787 return {Val, Val, Val, Val};
87938788 case SystemZISD::IPM: {
8794- auto *IPMOp0Node = N-> getOperand (0 ). getNode ( );
8795- if (!IPMOp0Node || IPMOp0Node != CCNode )
8789+ SDValue IPMOp0 = Val. getOperand (0 );
8790+ if (IPMOp0 != CC )
87968791 return {};
87978792 SmallVector<SDValue, 4 > ShiftedCCVals;
87988793 for (auto CC : {0 , 1 , 2 , 3 })
@@ -8801,12 +8796,10 @@ SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
88018796 return ShiftedCCVals;
88028797 }
88038798 case SystemZISD::SELECT_CCMASK: {
8804- SDValue TrueVal = N->getOperand (0 ), FalseVal = N->getOperand (1 );
8805- auto *TrueOp = TrueVal.getNode ();
8806- auto *FalseOp = FalseVal.getNode ();
8807- auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand (2 ));
8808- auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand (3 ));
8809- if (!TrueOp || !FalseOp || !CCValid || !CCMask)
8799+ SDValue TrueVal = Val.getOperand (0 ), FalseVal = Val.getOperand (1 );
8800+ auto *CCValid = dyn_cast<ConstantSDNode>(Val.getOperand (2 ));
8801+ auto *CCMask = dyn_cast<ConstantSDNode>(Val.getOperand (3 ));
8802+ if (!CCValid || !CCMask)
88108803 return {};
88118804
88128805 int CCValidVal = CCValid->getZExtValue ();
@@ -8815,12 +8808,10 @@ SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
88158808 const auto &&FalseSDVals = simplifyAssumingCCVal (FalseVal, CC, DAG);
88168809 if (TrueSDVals.empty () || FalseSDVals.empty ())
88178810 return {};
8818- SDValue Op4CCReg = N->getOperand (4 );
8819- auto *Op4CCNode = Op4CCReg.getNode ();
8820- if (Op4CCNode && Op4CCNode != CCNode)
8811+ SDValue Op4CCReg = Val.getOperand (4 );
8812+ if (Op4CCReg != CC)
88218813 combineCCMask (Op4CCReg, CCValidVal, CCMaskVal, DAG);
8822- Op4CCNode = Op4CCReg.getNode ();
8823- if (!Op4CCNode || Op4CCNode != CCNode)
8814+ if (Op4CCReg != CC)
88248815 return {};
88258816 SmallVector<SDValue, 4 > MergedSDVals;
88268817 for (auto &CCVal : {0 , 1 , 2 , 3 })
@@ -8834,12 +8825,14 @@ SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
88348825 case ISD::OR:
88358826 case ISD::XOR:
88368827 case ISD::SRA:
8837- if (!N->hasOneUse ())
8828+ // Avoid introducing CC spills (because ADD/AND/OR/XOR/SRA
8829+ // would clobber CC).
8830+ if (!Val.hasOneUse ())
88388831 return {};
88398832 [[fallthrough]];
88408833 case ISD::SHL:
88418834 case ISD::SRL:
8842- SDValue Op0 = N-> getOperand (0 ), Op1 = N-> getOperand (1 );
8835+ SDValue Op0 = Val. getOperand (0 ), Op1 = Val. getOperand (1 );
88438836 const auto &&Op0SDVals = simplifyAssumingCCVal (Op0, CC, DAG);
88448837 const auto &&Op1SDVals = simplifyAssumingCCVal (Op1, CC, DAG);
88458838 if (Op0SDVals.empty () || Op1SDVals.empty ())
@@ -8862,27 +8855,22 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
88628855 auto *CCNode = CCReg.getNode ();
88638856 if (!CCNode)
88648857 return false ;
8865- const auto getAPIntSDVals = [](const SmallVector<SDValue, 4 > &Vals) {
8866- SmallVector<APInt, 4 > APIntVals;
8867- for (const auto &Val : Vals) {
8868- auto *ConstValNode = dyn_cast<ConstantSDNode>(Val.getNode ());
8869- if (!ConstValNode)
8870- return SmallVector<APInt, 4 >();
8871- APIntVals.emplace_back (ConstValNode->getAPIntValue ());
8872- }
8873- return APIntVals;
8874- };
88758858
88768859 if (CCNode->getOpcode () == SystemZISD::TM) {
88778860 if (CCValid != SystemZ::CCMASK_TM)
88788861 return false ;
8879- auto emulateTMCCMask = [](const APInt &Op0Val, const APInt &Op1Val) {
8880- auto Result = Op0Val & Op1Val;
8881- bool AllOnes = Result == Op1Val;
8862+ auto emulateTMCCMask = [](const SDValue &Op0Val, const SDValue &Op1Val) {
8863+ auto *Op0Node = dyn_cast<ConstantSDNode>(Op0Val.getNode ());
8864+ auto *Op1Node = dyn_cast<ConstantSDNode>(Op1Val.getNode ());
8865+ if (!Op0Node || !Op1Node)
8866+ return -1 ;
8867+ auto Op0APVal = Op0Node->getAPIntValue ();
8868+ auto Op1APVal = Op1Node->getAPIntValue ();
8869+ auto Result = Op0APVal & Op1APVal;
8870+ bool AllOnes = Result == Op1APVal;
88828871 bool AllZeros = Result == 0 ;
8883- int MSBPos = Op1Val.countl_zero ();
8884- bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0 ;
8885- return AllOnes ? 3 : AllZeros ? 0 : IsLeftMostBitSet ? 2 : 1 ;
8872+ bool IsLeftMostBitSet = Result[Op1APVal.getActiveBits ()] != 0 ;
8873+ return AllZeros ? 0 : AllOnes ? 3 : IsLeftMostBitSet ? 2 : 1 ;
88868874 };
88878875 SDValue Op0 = CCNode->getOperand (0 );
88888876 SDValue Op1 = CCNode->getOperand (1 );
@@ -8893,20 +8881,18 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
88938881 const auto &&Op1SDVals = simplifyAssumingCCVal (Op1, Op0CC, DAG);
88948882 if (Op0SDVals.empty () || Op1SDVals.empty ())
88958883 return false ;
8896- auto &&Op0APInts = getAPIntSDVals (Op0SDVals);
8897- const auto &&Op1APInts = getAPIntSDVals (Op1SDVals);
8898- if (Op0APInts.empty () || Op1APInts.empty ())
8899- return false ;
89008884 SmallVector<int , 4 > CCVals;
8901- std::transform (Op0APInts .begin (), Op0APInts .end (), Op1APInts .begin (),
8885+ std::transform (Op0SDVals .begin (), Op0SDVals .end (), Op1SDVals .begin (),
89028886 std::back_inserter (CCVals), emulateTMCCMask);
8903- if (CCVals.empty ())
8887+ if (std::any_of (CCVals.begin (), CCVals.end (),
8888+ [](const auto Val) { return Val < 0 ; }))
89048889 return false ;
89058890 int NewCCMask = 0 ;
89068891 for (auto CC : CCVals) {
89078892 NewCCMask <<= 1 ;
89088893 NewCCMask |= (CCMask & (1 << (3 - CC))) != 0 ;
89098894 }
8895+ NewCCMask &= Op0CCValid;
89108896 CCReg = Op0CC;
89118897 CCMask = NewCCMask;
89128898 CCValid = Op0CCValid;
@@ -8925,23 +8911,26 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
89258911 const auto &&Op1SDVals = simplifyAssumingCCVal (CmpOp1, Op0CC, DAG);
89268912 if (Op0SDVals.empty () || Op1SDVals.empty ())
89278913 return false ;
8928- auto &&Op0APInts = getAPIntSDVals (Op0SDVals);
8929- const auto &&Op1APInts = getAPIntSDVals (Op1SDVals);
8930- if (Op0APInts.empty () || Op1APInts.empty ())
8931- return false ;
89328914
89338915 auto *CmpType = dyn_cast<ConstantSDNode>(CmpOp2);
89348916 auto CmpTypeVal = CmpType->getZExtValue ();
8935- const auto compareCCSigned = [&CmpTypeVal](const APInt &Op0Val,
8936- const APInt &Op1Val) {
8917+ const auto compareCCSigned = [&CmpTypeVal](const SDValue &Op0Val,
8918+ const SDValue &Op1Val) {
8919+ auto *Op0Node = dyn_cast<ConstantSDNode>(Op0Val.getNode ());
8920+ auto *Op1Node = dyn_cast<ConstantSDNode>(Op1Val.getNode ());
8921+ if (!Op0Node || !Op1Node)
8922+ return -1 ;
8923+ auto Op0APVal = Op0Node->getAPIntValue ();
8924+ auto Op1APVal = Op1Node->getAPIntValue ();
89378925 if (CmpTypeVal == SystemZICMP::SignedOnly)
8938- return Op0Val == Op1Val ? 0 : Op0Val .slt (Op1Val ) ? 1 : 2 ;
8939- return Op0Val == Op1Val ? 0 : Op0Val .ult (Op1Val ) ? 1 : 2 ;
8926+ return Op0APVal == Op1APVal ? 0 : Op0APVal .slt (Op1APVal ) ? 1 : 2 ;
8927+ return Op0APVal == Op1APVal ? 0 : Op0APVal .ult (Op1APVal ) ? 1 : 2 ;
89408928 };
89418929 SmallVector<int , 4 > CCVals;
8942- std::transform (Op0APInts .begin (), Op0APInts .end (), Op1APInts .begin (),
8930+ std::transform (Op0SDVals .begin (), Op0SDVals .end (), Op1SDVals .begin (),
89438931 std::back_inserter (CCVals), compareCCSigned);
8944- if (CCVals.empty ())
8932+ if (std::any_of (CCVals.begin (), CCVals.end (),
8933+ [](const auto Val) { return Val < 0 ; }))
89458934 return false ;
89468935 int NewCCMask = 0 ;
89478936 for (auto CC : CCVals) {
@@ -9029,10 +9018,10 @@ SDValue SystemZTargetLowering::combineSELECT_CCMASK(
90299018 int CCMaskVal = CCMask->getZExtValue ();
90309019 SDValue CCReg = N->getOperand (4 );
90319020
9032- if (!combineCCMask (CCReg, CCValidVal, CCMaskVal, DAG))
9033- return SDValue ();
9021+ bool UpdatedCCReg = combineCCMask (CCReg, CCValidVal, CCMaskVal, DAG);
90349022
9035- // Handle TrueVal and FalseVal in outermost select_ccmask.
9023+ // Attempting to optimize TrueVal/FalseVal in outermost select_ccmask either
9024+ // with CCReg found by combineCCMask or original CCReg.
90369025 SDValue TrueVal = N->getOperand (0 );
90379026 SDValue FalseVal = N->getOperand (1 );
90389027 const auto &&TrueSDVals = simplifyAssumingCCVal (TrueVal, CCReg, DAG);
@@ -9043,34 +9032,43 @@ SDValue SystemZTargetLowering::combineSELECT_CCMASK(
90439032 // and directly create a new SELECT_CCMASK.
90449033 if (!TrueSDVals.empty () && !FalseSDVals.empty ()) {
90459034 SmallVector<SDValue, 4 > MergedSDVals;
9035+ CCMaskVal &= CCValidVal;
90469036 for (auto CC : {0 , 1 , 2 , 3 }) {
90479037 MergedSDVals.emplace_back (((CCMaskVal & (1 << (3 - CC))) != 0 )
90489038 ? TrueSDVals[CC]
90499039 : FalseSDVals[CC]);
90509040 }
90519041 SmallSet<SDValue, 4 > MergedSDValsSet (MergedSDVals.begin (),
90529042 MergedSDVals.end ());
9053- if (MergedSDValsSet.size () != 2 )
9054- return SDValue ();
9055- auto BeginIt = MergedSDValsSet.begin ();
9056- SDValue NewTrueVal = *BeginIt, NewFalseVal = *next (BeginIt);
9057- if (NewTrueVal == FalseVal || NewFalseVal == TrueVal)
9058- std::swap (NewTrueVal, NewFalseVal);
9059- int NewCCMask = 0 ;
9060- for (auto CC : {0 , 1 , 2 , 3 }) {
9061- NewCCMask <<= 1 ;
9062- NewCCMask |= MergedSDVals[CC] == NewTrueVal;
9043+ if (MergedSDValsSet.size () == 1 )
9044+ return *MergedSDValsSet.begin ();
9045+ if (MergedSDValsSet.size () == 2 ) {
9046+ auto BeginIt = MergedSDValsSet.begin ();
9047+ SDValue NewTrueVal = *BeginIt, NewFalseVal = *next (BeginIt);
9048+ if (NewTrueVal == FalseVal || NewFalseVal == TrueVal)
9049+ std::swap (NewTrueVal, NewFalseVal);
9050+ int NewCCMask = 0 ;
9051+ for (auto CC : {0 , 1 , 2 , 3 }) {
9052+ NewCCMask <<= 1 ;
9053+ NewCCMask |= MergedSDVals[CC] == NewTrueVal;
9054+ }
9055+ CCMaskVal = NewCCMask;
9056+ TrueVal = NewTrueVal;
9057+ FalseVal = NewFalseVal;
9058+ return DAG.getNode (
9059+ SystemZISD::SELECT_CCMASK, SDLoc (N), N->getValueType (0 ), TrueVal,
9060+ FalseVal, DAG.getTargetConstant (CCValidVal, SDLoc (N), MVT::i32 ),
9061+ DAG.getTargetConstant (CCMaskVal, SDLoc (N), MVT::i32 ), CCReg);
90639062 }
9064- NewCCMask &= CCValidVal;
9065- CCMaskVal = NewCCMask;
9066- TrueVal = NewTrueVal;
9067- FalseVal = NewFalseVal;
90689063 }
90699064
9070- return DAG.getNode (
9071- SystemZISD::SELECT_CCMASK, SDLoc (N), N->getValueType (0 ), TrueVal,
9072- FalseVal, DAG.getTargetConstant (CCValidVal, SDLoc (N), MVT::i32 ),
9073- DAG.getTargetConstant (CCMaskVal, SDLoc (N), MVT::i32 ), CCReg);
9065+ if (UpdatedCCReg)
9066+ return DAG.getNode (
9067+ SystemZISD::SELECT_CCMASK, SDLoc (N), N->getValueType (0 ), TrueVal,
9068+ FalseVal, DAG.getTargetConstant (CCValidVal, SDLoc (N), MVT::i32 ),
9069+ DAG.getTargetConstant (CCMaskVal, SDLoc (N), MVT::i32 ), CCReg);
9070+
9071+ return SDValue ();
90749072}
90759073
90769074SDValue SystemZTargetLowering::combineGET_CCMASK (
0 commit comments