@@ -189,6 +189,11 @@ struct DAE : public Pass {
189189
190190 bool optimize = false ;
191191
192+ Index numFunctions;
193+
194+ // Map of function names to their index in module->functions, for speed below.
195+ std::unordered_map<Name, Index> indexes;
196+
192197 void run (Module* module ) override {
193198 DAEFunctionInfoMap infoMap;
194199 // Ensure all entries exist so the parallel threads don't modify the data
@@ -199,6 +204,12 @@ struct DAE : public Pass {
199204 // The null name represents module-level code (not in a function).
200205 infoMap[Name ()];
201206
207+ numFunctions = module ->functions .size ();
208+
209+ for (Index i = 0 ; i < numFunctions; i++) {
210+ indexes[module ->functions [i]->name ] = i;
211+ }
212+
202213 // Iterate to convergence.
203214 while (1 ) {
204215 if (!iteration (module , infoMap)) {
@@ -234,34 +245,36 @@ struct DAE : public Pass {
234245 Call* call;
235246 Function* func;
236247 };
237- std::map<Name, std::vector<Call*>> allCalls;
238- std::unordered_set<Name> tailCallees;
239- std::unordered_set<Name> hasUnseenCalls;
248+
249+ std::vector<std::vector<Call*>> allCalls (numFunctions);
250+ std::vector<bool > tailCallees (numFunctions);
251+ std::vector<bool > hasUnseenCalls (numFunctions);
252+
240253 // Track the function in which relevant expressions exist. When we modify
241254 // those expressions we will need to mark the function's info as stale.
242255 std::unordered_map<Expression*, Name> expressionFuncs;
243256 for (auto & [func, info] : infoMap) {
244257 for (auto & [name, calls] : info.calls ) {
245- auto & allCallsToName = allCalls[name];
258+ auto & allCallsToName = allCalls[indexes[ name] ];
246259 allCallsToName.insert (allCallsToName.end (), calls.begin (), calls.end ());
247260 for (auto * call : calls) {
248261 expressionFuncs[call] = func;
249262 }
250263 }
251264 for (auto & callee : info.tailCallees ) {
252- tailCallees. insert ( callee) ;
265+ tailCallees[indexes[ callee]] = true ;
253266 }
254267 for (auto & [call, dropp] : info.droppedCalls ) {
255268 allDroppedCalls[call] = dropp;
256269 }
257270 for (auto & name : info.hasUnseenCalls ) {
258- hasUnseenCalls. insert ( name) ;
271+ hasUnseenCalls[indexes[ name]] = true ;
259272 }
260273 }
261274 // Exports are considered unseen calls.
262275 for (auto & curr : module ->exports ) {
263276 if (curr->kind == ExternalKind::Function) {
264- hasUnseenCalls. insert ( *curr->getInternalName ()) ;
277+ hasUnseenCalls[indexes[ *curr->getInternalName ()]] = true ;
265278 }
266279 }
267280
@@ -300,23 +313,32 @@ struct DAE : public Pass {
300313
301314 // We now have a mapping of all call sites for each function, and can look
302315 // for optimization opportunities.
303- for (auto & [name, calls] : allCalls) {
316+ for (Index index = 0 ; index < numFunctions; index++) {
317+ auto * func = module ->functions [index].get ();
318+ if (func->imported ()) {
319+ continue ;
320+ }
304321 // We can only optimize if we see all the calls and can modify them.
305- if (hasUnseenCalls.count (name)) {
322+ if (hasUnseenCalls[index]) {
323+ continue ;
324+ }
325+ auto & calls = allCalls[index];
326+ if (calls.empty ()) {
327+ // Nothing calls this, so it is not worth optimizing.
306328 continue ;
307329 }
308- auto * func = module ->getFunction (name);
309330 // Refine argument types before doing anything else. This does not
310331 // affect whether an argument is used or not, it just refines the type
311332 // where possible.
333+ auto name = func->name ;
312334 if (refineArgumentTypes (func, calls, module , infoMap[name])) {
313335 worthOptimizing.insert (func);
314336 markStale (func->name );
315337 }
316338 // Refine return types as well.
317339 if (refineReturnTypes (func, calls, module )) {
318340 refinedReturnTypes = true ;
319- markStale (func-> name );
341+ markStale (name);
320342 markCallersStale (calls);
321343 }
322344 auto optimizedIndexes =
@@ -337,21 +359,29 @@ struct DAE : public Pass {
337359 ReFinalize ().run (getPassRunner (), module );
338360 }
339361 // We now know which parameters are unused, and can potentially remove them.
340- for (auto & [name, calls] : allCalls) {
341- if (hasUnseenCalls.count (name)) {
362+ for (Index index = 0 ; index < numFunctions; index++) {
363+ auto * func = module ->functions [index].get ();
364+ if (func->imported ()) {
365+ continue ;
366+ }
367+ if (hasUnseenCalls[index]) {
342368 continue ;
343369 }
344- auto * func = module ->getFunction (name);
345370 auto numParams = func->getNumParams ();
346371 if (numParams == 0 ) {
347372 continue ;
348373 }
374+ auto & calls = allCalls[index];
375+ if (calls.empty ()) {
376+ continue ;
377+ }
378+ auto name = func->name ;
349379 auto [removedIndexes, outcome] = ParamUtils::removeParameters (
350380 {func}, infoMap[name].unusedParams , calls, {}, module , getPassRunner ());
351381 if (!removedIndexes.empty ()) {
352382 // Success!
353383 worthOptimizing.insert (func);
354- markStale (func-> name );
384+ markStale (name);
355385 markCallersStale (calls);
356386 }
357387 if (outcome == ParamUtils::RemovalOutcome::Failure) {
@@ -363,25 +393,28 @@ struct DAE : public Pass {
363393 // modified allCalls (we can't modify a call site twice in one iteration,
364394 // once to remove a param, once to drop the return value).
365395 if (worthOptimizing.empty ()) {
366- for (auto & func : module ->functions ) {
396+ for (Index index = 0 ; index < numFunctions; index++) {
397+ auto & func = module ->functions [index];
398+ if (func->imported ()) {
399+ continue ;
400+ }
367401 if (func->getResults () == Type::none) {
368402 continue ;
369403 }
370- auto name = func->name ;
371- if (hasUnseenCalls.count (name)) {
404+ if (hasUnseenCalls[index]) {
372405 continue ;
373406 }
407+ auto name = func->name ;
374408 if (infoMap[name].hasTailCalls ) {
375409 continue ;
376410 }
377- if (tailCallees. count (name) ) {
411+ if (tailCallees[index] ) {
378412 continue ;
379413 }
380- auto iter = allCalls. find (name) ;
381- if (iter == allCalls. end ()) {
414+ auto & calls = allCalls[index] ;
415+ if (calls. empty ()) {
382416 continue ;
383417 }
384- auto & calls = iter->second ;
385418 bool allDropped =
386419 std::all_of (calls.begin (), calls.end (), [&](Call* call) {
387420 return allDroppedCalls.count (call);
@@ -398,7 +431,7 @@ struct DAE : public Pass {
398431 // TODO Removing a drop may also open optimization opportunities in the
399432 // callers.
400433 worthOptimizing.insert (func.get ());
401- markStale (func-> name );
434+ markStale (name);
402435 markCallersStale (calls);
403436 }
404437 }
0 commit comments