Skip to content

Commit 1915eb8

Browse files
author
anoopkg6
committed
Incorporate code review suggestions.
1 parent 86abe9b commit 1915eb8

File tree

1 file changed

+89
-91
lines changed

1 file changed

+89
-91
lines changed

llvm/lib/Target/SystemZ/SystemZISelLowering.cpp

Lines changed: 89 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -8735,30 +8735,26 @@ SDValue SystemZTargetLowering::combineSETCC(
87358735
}
87368736

87378737
static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
8738-
auto *N = Val.getNode();
8739-
if (!N)
8740-
return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
8741-
switch (N->getOpcode()) {
8738+
switch (Val.getOpcode()) {
87428739
default:
87438740
return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
87448741
case SystemZISD::IPM:
8745-
if (N->getOperand(0).getOpcode() == SystemZISD::CLC ||
8746-
N->getOperand(0).getOpcode() == SystemZISD::STRCMP)
8747-
return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP);
8748-
return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY);
8742+
if (Val.getOperand(0).getOpcode() == SystemZISD::CLC ||
8743+
Val.getOperand(0).getOpcode() == SystemZISD::STRCMP)
8744+
return std::make_pair(Val.getOperand(0), SystemZ::CCMASK_ICMP);
8745+
return std::make_pair(Val.getOperand(0), SystemZ::CCMASK_ANY);
87498746
case SystemZISD::SELECT_CCMASK: {
8750-
SDValue Op4CCReg = N->getOperand(4);
8751-
auto *Op4CCNode = Op4CCReg.getNode();
8752-
auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
8753-
if (!CCValid || !Op4CCNode)
8754-
return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
8755-
int CCValidVal = CCValid->getZExtValue();
8756-
if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
8757-
Op4CCNode->getOpcode() == SystemZISD::TM) {
8758-
auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0));
8747+
SDValue Op4CCReg = Val.getOperand(4);
8748+
if (Op4CCReg.getOpcode() == SystemZISD::ICMP ||
8749+
Op4CCReg.getOpcode() == SystemZISD::TM) {
8750+
auto [OpCC, OpCCValid] = findCCUse(Op4CCReg.getOperand(0));
87598751
if (OpCC != SDValue())
87608752
return std::make_pair(OpCC, OpCCValid);
87618753
}
8754+
auto *CCValid = dyn_cast<ConstantSDNode>(Val.getOperand(2));
8755+
if (!CCValid)
8756+
return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
8757+
int CCValidVal = CCValid->getZExtValue();
87628758
return std::make_pair(Op4CCReg, CCValidVal);
87638759
}
87648760
case ISD::ADD:
@@ -8768,10 +8764,10 @@ static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
87688764
case ISD::SHL:
87698765
case ISD::SRA:
87708766
case ISD::SRL:
8771-
auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0));
8767+
auto [Op0CC, Op0CCValid] = findCCUse(Val.getOperand(0));
87728768
if (Op0CC != SDValue())
87738769
return std::make_pair(Op0CC, Op0CCValid);
8774-
return findCCUse(N->getOperand(1));
8770+
return findCCUse(Val.getOperand(1));
87758771
}
87768772
}
87778773

@@ -8780,19 +8776,18 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
87808776

87818777
SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
87828778
SelectionDAG &DAG) {
8783-
auto *N = Val.getNode(), *CCNode = CC.getNode();
8784-
if (!N || !CCNode)
8779+
if (CC == SDValue())
87858780
return {};
8786-
SDLoc DL(N);
8787-
auto Opcode = N->getOpcode();
8781+
SDLoc DL(Val);
8782+
auto Opcode = Val.getOpcode();
87888783
switch (Opcode) {
87898784
default:
87908785
return {};
87918786
case ISD::Constant:
87928787
return {Val, Val, Val, Val};
87938788
case SystemZISD::IPM: {
8794-
auto *IPMOp0Node = N->getOperand(0).getNode();
8795-
if (!IPMOp0Node || IPMOp0Node != CCNode)
8789+
SDValue IPMOp0 = Val.getOperand(0);
8790+
if (IPMOp0 != CC)
87968791
return {};
87978792
SmallVector<SDValue, 4> ShiftedCCVals;
87988793
for (auto CC : {0, 1, 2, 3})
@@ -8801,12 +8796,10 @@ SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
88018796
return ShiftedCCVals;
88028797
}
88038798
case SystemZISD::SELECT_CCMASK: {
8804-
SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
8805-
auto *TrueOp = TrueVal.getNode();
8806-
auto *FalseOp = FalseVal.getNode();
8807-
auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
8808-
auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
8809-
if (!TrueOp || !FalseOp || !CCValid || !CCMask)
8799+
SDValue TrueVal = Val.getOperand(0), FalseVal = Val.getOperand(1);
8800+
auto *CCValid = dyn_cast<ConstantSDNode>(Val.getOperand(2));
8801+
auto *CCMask = dyn_cast<ConstantSDNode>(Val.getOperand(3));
8802+
if (!CCValid || !CCMask)
88108803
return {};
88118804

88128805
int CCValidVal = CCValid->getZExtValue();
@@ -8815,12 +8808,10 @@ SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
88158808
const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DAG);
88168809
if (TrueSDVals.empty() || FalseSDVals.empty())
88178810
return {};
8818-
SDValue Op4CCReg = N->getOperand(4);
8819-
auto *Op4CCNode = Op4CCReg.getNode();
8820-
if (Op4CCNode && Op4CCNode != CCNode)
8811+
SDValue Op4CCReg = Val.getOperand(4);
8812+
if (Op4CCReg != CC)
88218813
combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DAG);
8822-
Op4CCNode = Op4CCReg.getNode();
8823-
if (!Op4CCNode || Op4CCNode != CCNode)
8814+
if (Op4CCReg != CC)
88248815
return {};
88258816
SmallVector<SDValue, 4> MergedSDVals;
88268817
for (auto &CCVal : {0, 1, 2, 3})
@@ -8834,12 +8825,14 @@ SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
88348825
case ISD::OR:
88358826
case ISD::XOR:
88368827
case ISD::SRA:
8837-
if (!N->hasOneUse())
8828+
// Avoid introducing CC spills (because ADD/AND/OR/XOR/SRA
8829+
// would clobber CC).
8830+
if (!Val.hasOneUse())
88388831
return {};
88398832
[[fallthrough]];
88408833
case ISD::SHL:
88418834
case ISD::SRL:
8842-
SDValue Op0 = N->getOperand(0), Op1 = N->getOperand(1);
8835+
SDValue Op0 = Val.getOperand(0), Op1 = Val.getOperand(1);
88438836
const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, CC, DAG);
88448837
const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, CC, DAG);
88458838
if (Op0SDVals.empty() || Op1SDVals.empty())
@@ -8862,27 +8855,22 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
88628855
auto *CCNode = CCReg.getNode();
88638856
if (!CCNode)
88648857
return false;
8865-
const auto getAPIntSDVals = [](const SmallVector<SDValue, 4> &Vals) {
8866-
SmallVector<APInt, 4> APIntVals;
8867-
for (const auto &Val : Vals) {
8868-
auto *ConstValNode = dyn_cast<ConstantSDNode>(Val.getNode());
8869-
if (!ConstValNode)
8870-
return SmallVector<APInt, 4>();
8871-
APIntVals.emplace_back(ConstValNode->getAPIntValue());
8872-
}
8873-
return APIntVals;
8874-
};
88758858

88768859
if (CCNode->getOpcode() == SystemZISD::TM) {
88778860
if (CCValid != SystemZ::CCMASK_TM)
88788861
return false;
8879-
auto emulateTMCCMask = [](const APInt &Op0Val, const APInt &Op1Val) {
8880-
auto Result = Op0Val & Op1Val;
8881-
bool AllOnes = Result == Op1Val;
8862+
auto emulateTMCCMask = [](const SDValue &Op0Val, const SDValue &Op1Val) {
8863+
auto *Op0Node = dyn_cast<ConstantSDNode>(Op0Val.getNode());
8864+
auto *Op1Node = dyn_cast<ConstantSDNode>(Op1Val.getNode());
8865+
if (!Op0Node || !Op1Node)
8866+
return -1;
8867+
auto Op0APVal = Op0Node->getAPIntValue();
8868+
auto Op1APVal = Op1Node->getAPIntValue();
8869+
auto Result = Op0APVal & Op1APVal;
8870+
bool AllOnes = Result == Op1APVal;
88828871
bool AllZeros = Result == 0;
8883-
int MSBPos = Op1Val.countl_zero();
8884-
bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0;
8885-
return AllOnes ? 3 : AllZeros ? 0 : IsLeftMostBitSet ? 2 : 1;
8872+
bool IsLeftMostBitSet = Result[Op1APVal.getActiveBits()] != 0;
8873+
return AllZeros ? 0 : AllOnes ? 3 : IsLeftMostBitSet ? 2 : 1;
88868874
};
88878875
SDValue Op0 = CCNode->getOperand(0);
88888876
SDValue Op1 = CCNode->getOperand(1);
@@ -8893,20 +8881,18 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
88938881
const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, Op0CC, DAG);
88948882
if (Op0SDVals.empty() || Op1SDVals.empty())
88958883
return false;
8896-
auto &&Op0APInts = getAPIntSDVals(Op0SDVals);
8897-
const auto &&Op1APInts = getAPIntSDVals(Op1SDVals);
8898-
if (Op0APInts.empty() || Op1APInts.empty())
8899-
return false;
89008884
SmallVector<int, 4> CCVals;
8901-
std::transform(Op0APInts.begin(), Op0APInts.end(), Op1APInts.begin(),
8885+
std::transform(Op0SDVals.begin(), Op0SDVals.end(), Op1SDVals.begin(),
89028886
std::back_inserter(CCVals), emulateTMCCMask);
8903-
if (CCVals.empty())
8887+
if (std::any_of(CCVals.begin(), CCVals.end(),
8888+
[](const auto Val) { return Val < 0; }))
89048889
return false;
89058890
int NewCCMask = 0;
89068891
for (auto CC : CCVals) {
89078892
NewCCMask <<= 1;
89088893
NewCCMask |= (CCMask & (1 << (3 - CC))) != 0;
89098894
}
8895+
NewCCMask &= Op0CCValid;
89108896
CCReg = Op0CC;
89118897
CCMask = NewCCMask;
89128898
CCValid = Op0CCValid;
@@ -8925,23 +8911,26 @@ static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
89258911
const auto &&Op1SDVals = simplifyAssumingCCVal(CmpOp1, Op0CC, DAG);
89268912
if (Op0SDVals.empty() || Op1SDVals.empty())
89278913
return false;
8928-
auto &&Op0APInts = getAPIntSDVals(Op0SDVals);
8929-
const auto &&Op1APInts = getAPIntSDVals(Op1SDVals);
8930-
if (Op0APInts.empty() || Op1APInts.empty())
8931-
return false;
89328914

89338915
auto *CmpType = dyn_cast<ConstantSDNode>(CmpOp2);
89348916
auto CmpTypeVal = CmpType->getZExtValue();
8935-
const auto compareCCSigned = [&CmpTypeVal](const APInt &Op0Val,
8936-
const APInt &Op1Val) {
8917+
const auto compareCCSigned = [&CmpTypeVal](const SDValue &Op0Val,
8918+
const SDValue &Op1Val) {
8919+
auto *Op0Node = dyn_cast<ConstantSDNode>(Op0Val.getNode());
8920+
auto *Op1Node = dyn_cast<ConstantSDNode>(Op1Val.getNode());
8921+
if (!Op0Node || !Op1Node)
8922+
return -1;
8923+
auto Op0APVal = Op0Node->getAPIntValue();
8924+
auto Op1APVal = Op1Node->getAPIntValue();
89378925
if (CmpTypeVal == SystemZICMP::SignedOnly)
8938-
return Op0Val == Op1Val ? 0 : Op0Val.slt(Op1Val) ? 1 : 2;
8939-
return Op0Val == Op1Val ? 0 : Op0Val.ult(Op1Val) ? 1 : 2;
8926+
return Op0APVal == Op1APVal ? 0 : Op0APVal.slt(Op1APVal) ? 1 : 2;
8927+
return Op0APVal == Op1APVal ? 0 : Op0APVal.ult(Op1APVal) ? 1 : 2;
89408928
};
89418929
SmallVector<int, 4> CCVals;
8942-
std::transform(Op0APInts.begin(), Op0APInts.end(), Op1APInts.begin(),
8930+
std::transform(Op0SDVals.begin(), Op0SDVals.end(), Op1SDVals.begin(),
89438931
std::back_inserter(CCVals), compareCCSigned);
8944-
if (CCVals.empty())
8932+
if (std::any_of(CCVals.begin(), CCVals.end(),
8933+
[](const auto Val) { return Val < 0; }))
89458934
return false;
89468935
int NewCCMask = 0;
89478936
for (auto CC : CCVals) {
@@ -9029,10 +9018,10 @@ SDValue SystemZTargetLowering::combineSELECT_CCMASK(
90299018
int CCMaskVal = CCMask->getZExtValue();
90309019
SDValue CCReg = N->getOperand(4);
90319020

9032-
if (!combineCCMask(CCReg, CCValidVal, CCMaskVal, DAG))
9033-
return SDValue();
9021+
bool UpdatedCCReg = combineCCMask(CCReg, CCValidVal, CCMaskVal, DAG);
90349022

9035-
// Handle TrueVal and FalseVal in outermost select_ccmask.
9023+
// Attempting to optimize TrueVal/FalseVal in outermost select_ccmask either
9024+
// with CCReg found by combineCCMask or original CCReg.
90369025
SDValue TrueVal = N->getOperand(0);
90379026
SDValue FalseVal = N->getOperand(1);
90389027
const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CCReg, DAG);
@@ -9043,34 +9032,43 @@ SDValue SystemZTargetLowering::combineSELECT_CCMASK(
90439032
// and directly create a new SELECT_CCMASK.
90449033
if (!TrueSDVals.empty() && !FalseSDVals.empty()) {
90459034
SmallVector<SDValue, 4> MergedSDVals;
9035+
CCMaskVal &= CCValidVal;
90469036
for (auto CC : {0, 1, 2, 3}) {
90479037
MergedSDVals.emplace_back(((CCMaskVal & (1 << (3 - CC))) != 0)
90489038
? TrueSDVals[CC]
90499039
: FalseSDVals[CC]);
90509040
}
90519041
SmallSet<SDValue, 4> MergedSDValsSet(MergedSDVals.begin(),
90529042
MergedSDVals.end());
9053-
if (MergedSDValsSet.size() != 2)
9054-
return SDValue();
9055-
auto BeginIt = MergedSDValsSet.begin();
9056-
SDValue NewTrueVal = *BeginIt, NewFalseVal = *next(BeginIt);
9057-
if (NewTrueVal == FalseVal || NewFalseVal == TrueVal)
9058-
std::swap(NewTrueVal, NewFalseVal);
9059-
int NewCCMask = 0;
9060-
for (auto CC : {0, 1, 2, 3}) {
9061-
NewCCMask <<= 1;
9062-
NewCCMask |= MergedSDVals[CC] == NewTrueVal;
9043+
if (MergedSDValsSet.size() == 1)
9044+
return *MergedSDValsSet.begin();
9045+
if (MergedSDValsSet.size() == 2) {
9046+
auto BeginIt = MergedSDValsSet.begin();
9047+
SDValue NewTrueVal = *BeginIt, NewFalseVal = *next(BeginIt);
9048+
if (NewTrueVal == FalseVal || NewFalseVal == TrueVal)
9049+
std::swap(NewTrueVal, NewFalseVal);
9050+
int NewCCMask = 0;
9051+
for (auto CC : {0, 1, 2, 3}) {
9052+
NewCCMask <<= 1;
9053+
NewCCMask |= MergedSDVals[CC] == NewTrueVal;
9054+
}
9055+
CCMaskVal = NewCCMask;
9056+
TrueVal = NewTrueVal;
9057+
FalseVal = NewFalseVal;
9058+
return DAG.getNode(
9059+
SystemZISD::SELECT_CCMASK, SDLoc(N), N->getValueType(0), TrueVal,
9060+
FalseVal, DAG.getTargetConstant(CCValidVal, SDLoc(N), MVT::i32),
9061+
DAG.getTargetConstant(CCMaskVal, SDLoc(N), MVT::i32), CCReg);
90639062
}
9064-
NewCCMask &= CCValidVal;
9065-
CCMaskVal = NewCCMask;
9066-
TrueVal = NewTrueVal;
9067-
FalseVal = NewFalseVal;
90689063
}
90699064

9070-
return DAG.getNode(
9071-
SystemZISD::SELECT_CCMASK, SDLoc(N), N->getValueType(0), TrueVal,
9072-
FalseVal, DAG.getTargetConstant(CCValidVal, SDLoc(N), MVT::i32),
9073-
DAG.getTargetConstant(CCMaskVal, SDLoc(N), MVT::i32), CCReg);
9065+
if (UpdatedCCReg)
9066+
return DAG.getNode(
9067+
SystemZISD::SELECT_CCMASK, SDLoc(N), N->getValueType(0), TrueVal,
9068+
FalseVal, DAG.getTargetConstant(CCValidVal, SDLoc(N), MVT::i32),
9069+
DAG.getTargetConstant(CCMaskVal, SDLoc(N), MVT::i32), CCReg);
9070+
9071+
return SDValue();
90749072
}
90759073

90769074
SDValue SystemZTargetLowering::combineGET_CCMASK(

0 commit comments

Comments
 (0)