@@ -240,26 +240,21 @@ struct DAE : public Pass {
240240 scanner.walkModuleCode (module );
241241 // Scan all the functions.
242242 scanner.run (getPassRunner (), module );
243- // Combine all the info.
244- struct CallContext {
245- Call* call;
246- Function* func;
247- };
248243
244+ // Combine all the info from the scan.
249245 std::vector<std::vector<Call*>> allCalls (numFunctions);
250246 std::vector<bool > tailCallees (numFunctions);
251247 std::vector<bool > hasUnseenCalls (numFunctions);
252248
253- // Track the function in which relevant expressions exist. When we modify
254- // those expressions we will need to mark the function's info as stale.
255- std::unordered_map<Expression*, Name> expressionFuncs;
249+ // For each function, the set of callers.
250+ std::vector<std::unordered_set<Name>> callers (numFunctions);
251+
256252 for (auto & [func, info] : infoMap) {
257253 for (auto & [name, calls] : info.calls ) {
258- auto & allCallsToName = allCalls[indexes[name]];
254+ auto targetIndex = indexes[name];
255+ auto & allCallsToName = allCalls[targetIndex];
259256 allCallsToName.insert (allCallsToName.end (), calls.begin (), calls.end ());
260- for (auto * call : calls) {
261- expressionFuncs[call] = func;
262- }
257+ callers[targetIndex].insert (func);
263258 }
264259 for (auto & callee : info.tailCallees ) {
265260 tailCallees[indexes[callee]] = true ;
@@ -305,9 +300,9 @@ struct DAE : public Pass {
305300 assert (func.is ());
306301 infoMap[func].markStale ();
307302 };
308- auto markCallersStale = [&](const std::vector<Call*>& calls ) {
309- for (auto * call : calls ) {
310- markStale (expressionFuncs[call] );
303+ auto markCallersStale = [&](Index index ) {
304+ for (auto caller : callers[index] ) {
305+ markStale (caller );
311306 }
312307 };
313308
@@ -339,7 +334,7 @@ struct DAE : public Pass {
339334 if (refineReturnTypes (func, calls, module )) {
340335 refinedReturnTypes = true ;
341336 markStale (name);
342- markCallersStale (calls );
337+ markCallersStale (index );
343338 }
344339 auto optimizedIndexes =
345340 ParamUtils::applyConstantValues ({func}, calls, {}, module );
@@ -382,7 +377,7 @@ struct DAE : public Pass {
382377 // Success!
383378 worthOptimizing.insert (func);
384379 markStale (name);
385- markCallersStale (calls );
380+ markCallersStale (index );
386381 }
387382 if (outcome == ParamUtils::RemovalOutcome::Failure) {
388383 callTargetsToLocalize.insert (name);
@@ -424,15 +419,15 @@ struct DAE : public Pass {
424419 }
425420 if (removeReturnValue (func.get (), calls, module )) {
426421 // We should optimize the callers.
427- for (auto * call : calls ) {
428- worthOptimizing.insert (module ->getFunction (expressionFuncs[call] ));
422+ for (auto caller : callers[index] ) {
423+ worthOptimizing.insert (module ->getFunction (caller ));
429424 }
430425 }
431426 // TODO Removing a drop may also open optimization opportunities in the
432427 // callers.
433428 worthOptimizing.insert (func.get ());
434429 markStale (name);
435- markCallersStale (calls );
430+ markCallersStale (index );
436431 }
437432 }
438433 if (!callTargetsToLocalize.empty ()) {
0 commit comments