Skip to content

Commit f27092f

Browse files
Check final type before we prevent merges
Signed-off-by: Mikhail R. Gadelha <[email protected]>
1 parent 04bca6d commit f27092f

File tree

1 file changed

+54
-44
lines changed

1 file changed

+54
-44
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -21363,42 +21363,8 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
2136321363
// must not be zext, volatile, indexed, and they must be consecutive.
2136421364
BaseIndexOffset LdBasePtr;
2136521365

21366-
// Check if a call exists in the store chain.
21367-
auto HasCallInLdStChain = [](SDNode *Load, SDNode *Store) {
21368-
SmallPtrSet<const SDNode *, 32> Visited;
21369-
SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21370-
Worklist.emplace_back(Store->getOperand(0).getNode(), false);
21371-
21372-
while (!Worklist.empty()) {
21373-
auto [Node, FoundCall] = Worklist.pop_back_val();
21374-
if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
21375-
continue;
21376-
21377-
switch (Node->getOpcode()) {
21378-
case ISD::CALLSEQ_END:
21379-
Worklist.emplace_back(Node->getOperand(0).getNode(), true);
21380-
break;
21381-
case ISD::TokenFactor:
21382-
for (SDValue Op : Node->ops())
21383-
Worklist.emplace_back(Op.getNode(), FoundCall);
21384-
break;
21385-
case ISD::LOAD:
21386-
if (Node == Load)
21387-
return FoundCall;
21388-
[[fallthrough]];
21389-
default:
21390-
if (Node->getNumOperands() > 0)
21391-
Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
21392-
break;
21393-
}
21394-
}
21395-
return false;
21396-
};
21397-
21398-
auto StIt = StoreNodes.begin();
21399-
unsigned i = 0;
21400-
while (StIt != StoreNodes.end() && i++ < NumConsecutiveStores) {
21401-
StoreSDNode *St = cast<StoreSDNode>(StIt->MemNode);
21366+
for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21367+
StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
2140221368
SDValue Val = peekThroughBitcasts(St->getValue());
2140321369
LoadSDNode *Ld = cast<LoadSDNode>(Val);
2140421370

@@ -21414,14 +21380,8 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
2141421380
LdBasePtr = LdPtr;
2141521381
}
2141621382

21417-
// Check if there is a call in the load/store chain.
21418-
if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT) &&
21419-
HasCallInLdStChain(Ld, St)) {
21420-
StIt = StoreNodes.erase(StIt);
21421-
} else {
21422-
LoadNodes.push_back(MemOpLink(Ld, LdOffset));
21423-
++StIt;
21424-
}
21383+
// We found a potential memory operand to merge.
21384+
LoadNodes.push_back(MemOpLink(Ld, LdOffset));
2142521385
}
2142621386

2142721387
while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
@@ -21593,6 +21553,56 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
2159321553
JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
2159421554
}
2159521555

21556+
auto HasCallInLdStChain = [](SmallVectorImpl<MemOpLink> &StoreNodes,
21557+
SmallVectorImpl<MemOpLink> &LoadNodes,
21558+
unsigned NumStores) {
21559+
for (unsigned i = 0; i < NumStores; ++i) {
21560+
StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
21561+
SDValue Val = peekThroughBitcasts(St->getValue());
21562+
LoadSDNode *Ld = cast<LoadSDNode>(Val);
21563+
assert(Ld == LoadNodes[i].MemNode && "Load and store mismatch");
21564+
21565+
SmallPtrSet<const SDNode *, 32> Visited;
21566+
SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21567+
Worklist.emplace_back(St->getOperand(0).getNode(), false);
21568+
21569+
while (!Worklist.empty()) {
21570+
auto [Node, FoundCall] = Worklist.pop_back_val();
21571+
if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
21572+
continue;
21573+
21574+
switch (Node->getOpcode()) {
21575+
case ISD::CALLSEQ_END:
21576+
Worklist.emplace_back(Node->getOperand(0).getNode(), true);
21577+
break;
21578+
case ISD::TokenFactor:
21579+
for (SDValue Op : Node->ops())
21580+
Worklist.emplace_back(Op.getNode(), FoundCall);
21581+
break;
21582+
case ISD::LOAD:
21583+
if (Node == Ld)
21584+
return FoundCall;
21585+
[[fallthrough]];
21586+
default:
21587+
if (Node->getNumOperands() > 0)
21588+
Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
21589+
break;
21590+
}
21591+
}
21592+
return false;
21593+
}
21594+
return false;
21595+
};
21596+
21597+
// Check if there is a call in the load/store chain.
21598+
if (!TLI.shouldMergeStoreOfLoadsOverCall(JointMemOpVT) &&
21599+
HasCallInLdStChain(StoreNodes, LoadNodes, NumElem)) {
21600+
StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21601+
LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
21602+
NumConsecutiveStores -= NumElem;
21603+
continue;
21604+
}
21605+
2159621606
SDLoc LoadDL(LoadNodes[0].MemNode);
2159721607
SDLoc StoreDL(StoreNodes[0].MemNode);
2159821608

0 commit comments

Comments
 (0)