@@ -21363,42 +21363,8 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
2136321363 // must not be zext, volatile, indexed, and they must be consecutive.
2136421364 BaseIndexOffset LdBasePtr;
2136521365
21366- // Check if a call exists in the store chain.
21367- auto HasCallInLdStChain = [](SDNode *Load, SDNode *Store) {
21368- SmallPtrSet<const SDNode *, 32> Visited;
21369- SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21370- Worklist.emplace_back(Store->getOperand(0).getNode(), false);
21371-
21372- while (!Worklist.empty()) {
21373- auto [Node, FoundCall] = Worklist.pop_back_val();
21374- if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
21375- continue;
21376-
21377- switch (Node->getOpcode()) {
21378- case ISD::CALLSEQ_END:
21379- Worklist.emplace_back(Node->getOperand(0).getNode(), true);
21380- break;
21381- case ISD::TokenFactor:
21382- for (SDValue Op : Node->ops())
21383- Worklist.emplace_back(Op.getNode(), FoundCall);
21384- break;
21385- case ISD::LOAD:
21386- if (Node == Load)
21387- return FoundCall;
21388- [[fallthrough]];
21389- default:
21390- if (Node->getNumOperands() > 0)
21391- Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
21392- break;
21393- }
21394- }
21395- return false;
21396- };
21397-
21398- auto StIt = StoreNodes.begin();
21399- unsigned i = 0;
21400- while (StIt != StoreNodes.end() && i++ < NumConsecutiveStores) {
21401- StoreSDNode *St = cast<StoreSDNode>(StIt->MemNode);
21366+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21367+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
2140221368 SDValue Val = peekThroughBitcasts(St->getValue());
2140321369 LoadSDNode *Ld = cast<LoadSDNode>(Val);
2140421370
@@ -21414,14 +21380,8 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
2141421380 LdBasePtr = LdPtr;
2141521381 }
2141621382
21417- // Check if there is a call in the load/store chain.
21418- if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT) &&
21419- HasCallInLdStChain(Ld, St)) {
21420- StIt = StoreNodes.erase(StIt);
21421- } else {
21422- LoadNodes.push_back(MemOpLink(Ld, LdOffset));
21423- ++StIt;
21424- }
21383+ // We found a potential memory operand to merge.
21384+ LoadNodes.push_back(MemOpLink(Ld, LdOffset));
2142521385 }
2142621386
2142721387 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
@@ -21593,6 +21553,56 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
2159321553 JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
2159421554 }
2159521555
21556+ auto HasCallInLdStChain = [](SmallVectorImpl<MemOpLink> &StoreNodes,
21557+ SmallVectorImpl<MemOpLink> &LoadNodes,
21558+ unsigned NumStores) {
21559+ for (unsigned i = 0; i < NumStores; ++i) {
21560+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
21561+ SDValue Val = peekThroughBitcasts(St->getValue());
21562+ LoadSDNode *Ld = cast<LoadSDNode>(Val);
21563+ assert(Ld == LoadNodes[i].MemNode && "Load and store mismatch");
21564+
21565+ SmallPtrSet<const SDNode *, 32> Visited;
21566+ SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21567+ Worklist.emplace_back(St->getOperand(0).getNode(), false);
21568+
21569+ while (!Worklist.empty()) {
21570+ auto [Node, FoundCall] = Worklist.pop_back_val();
21571+ if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
21572+ continue;
21573+
21574+ switch (Node->getOpcode()) {
21575+ case ISD::CALLSEQ_END:
21576+ Worklist.emplace_back(Node->getOperand(0).getNode(), true);
21577+ break;
21578+ case ISD::TokenFactor:
21579+ for (SDValue Op : Node->ops())
21580+ Worklist.emplace_back(Op.getNode(), FoundCall);
21581+ break;
21582+ case ISD::LOAD:
21583+ if (Node == Ld)
21584+ return FoundCall;
21585+ [[fallthrough]];
21586+ default:
21587+ if (Node->getNumOperands() > 0)
21588+ Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
21589+ break;
21590+ }
21591+ }
21592+ return false;
21593+ }
21594+ return false;
21595+ };
21596+
21597+ // Check if there is a call in the load/store chain.
21598+ if (!TLI.shouldMergeStoreOfLoadsOverCall(JointMemOpVT) &&
21599+ HasCallInLdStChain(StoreNodes, LoadNodes, NumElem)) {
21600+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21601+ LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
21602+ NumConsecutiveStores -= NumElem;
21603+ continue;
21604+ }
21605+
2159621606 SDLoc LoadDL(LoadNodes[0].MemNode);
2159721607 SDLoc StoreDL(StoreNodes[0].MemNode);
2159821608
0 commit comments