@@ -94,6 +94,7 @@ using DAEFunctionInfoMap = std::unordered_map<Name, DAEFunctionInfo>;
94
94
struct DAEScanner
95
95
: public WalkerPass<PostWalker<DAEScanner, Visitor<DAEScanner>>> {
96
96
bool isFunctionParallel () override { return true ; }
97
+ bool modifiesBinaryenIR () override { return false ; }
97
98
98
99
std::unique_ptr<Pass> create () override {
99
100
return std::make_unique<DAEScanner>(infoMap);
@@ -188,6 +189,11 @@ struct DAE : public Pass {
188
189
189
190
bool optimize = false ;
190
191
192
+ Index numFunctions;
193
+
194
+ // Map of function names to indexes. This lets us use indexes below for speed.
195
+ std::unordered_map<Name, Index> indexes;
196
+
191
197
void run (Module* module ) override {
192
198
DAEFunctionInfoMap infoMap;
193
199
// Ensure all entries exist so the parallel threads don't modify the data
@@ -198,6 +204,12 @@ struct DAE : public Pass {
198
204
// The null name represents module-level code (not in a function).
199
205
infoMap[Name ()];
200
206
207
+ numFunctions = module ->functions .size ();
208
+
209
+ for (Index i = 0 ; i < numFunctions; i++) {
210
+ indexes[module ->functions [i]->name ] = i;
211
+ }
212
+
201
213
// Iterate to convergence.
202
214
while (1 ) {
203
215
if (!iteration (module , infoMap)) {
@@ -233,34 +245,36 @@ struct DAE : public Pass {
233
245
Call* call;
234
246
Function* func;
235
247
};
236
- std::map<Name, std::vector<Call*>> allCalls;
237
- std::unordered_set<Name> tailCallees;
238
- std::unordered_set<Name> hasUnseenCalls;
248
+
249
+ std::vector<std::vector<Call*>> allCalls (numFunctions);
250
+ std::vector<bool > tailCallees (numFunctions);
251
+ std::vector<bool > hasUnseenCalls (numFunctions);
252
+
239
253
// Track the function in which relevant expressions exist. When we modify
240
254
// those expressions we will need to mark the function's info as stale.
241
255
std::unordered_map<Expression*, Name> expressionFuncs;
242
256
for (auto & [func, info] : infoMap) {
243
257
for (auto & [name, calls] : info.calls ) {
244
- auto & allCallsToName = allCalls[name];
258
+ auto & allCallsToName = allCalls[indexes[ name] ];
245
259
allCallsToName.insert (allCallsToName.end (), calls.begin (), calls.end ());
246
260
for (auto * call : calls) {
247
261
expressionFuncs[call] = func;
248
262
}
249
263
}
250
264
for (auto & callee : info.tailCallees ) {
251
- tailCallees. insert ( callee) ;
265
+ tailCallees[indexes[ callee]] = true ;
252
266
}
253
267
for (auto & [call, dropp] : info.droppedCalls ) {
254
268
allDroppedCalls[call] = dropp;
255
269
}
256
270
for (auto & name : info.hasUnseenCalls ) {
257
- hasUnseenCalls. insert ( name) ;
271
+ hasUnseenCalls[indexes[ name]] = true ;
258
272
}
259
273
}
260
274
// Exports are considered unseen calls.
261
275
for (auto & curr : module ->exports ) {
262
276
if (curr->kind == ExternalKind::Function) {
263
- hasUnseenCalls. insert ( *curr->getInternalName ()) ;
277
+ hasUnseenCalls[indexes[ *curr->getInternalName ()]] = true ;
264
278
}
265
279
}
266
280
@@ -299,23 +313,32 @@ struct DAE : public Pass {
299
313
300
314
// We now have a mapping of all call sites for each function, and can look
301
315
// for optimization opportunities.
302
- for (auto & [name, calls] : allCalls) {
316
+ for (Index index = 0 ; index < numFunctions; index++) {
317
+ auto * func = module ->functions [index].get ();
318
+ if (func->imported ()) {
319
+ continue ;
320
+ }
303
321
// We can only optimize if we see all the calls and can modify them.
304
- if (hasUnseenCalls.count (name)) {
322
+ if (hasUnseenCalls[index]) {
323
+ continue ;
324
+ }
325
+ auto & calls = allCalls[index];
326
+ if (calls.empty ()) {
327
+ // Nothing calls this, so it is not worth optimizing.
305
328
continue ;
306
329
}
307
- auto * func = module ->getFunction (name);
308
330
// Refine argument types before doing anything else. This does not
309
331
// affect whether an argument is used or not, it just refines the type
310
332
// where possible.
333
+ auto name = func->name ;
311
334
if (refineArgumentTypes (func, calls, module , infoMap[name])) {
312
335
worthOptimizing.insert (func);
313
336
markStale (func->name );
314
337
}
315
338
// Refine return types as well.
316
339
if (refineReturnTypes (func, calls, module )) {
317
340
refinedReturnTypes = true ;
318
- markStale (func-> name );
341
+ markStale (name);
319
342
markCallersStale (calls);
320
343
}
321
344
auto optimizedIndexes =
@@ -336,21 +359,29 @@ struct DAE : public Pass {
336
359
ReFinalize ().run (getPassRunner (), module );
337
360
}
338
361
// We now know which parameters are unused, and can potentially remove them.
339
- for (auto & [name, calls] : allCalls) {
340
- if (hasUnseenCalls.count (name)) {
362
+ for (Index index = 0 ; index < numFunctions; index++) {
363
+ auto * func = module ->functions [index].get ();
364
+ if (func->imported ()) {
365
+ continue ;
366
+ }
367
+ if (hasUnseenCalls[index]) {
341
368
continue ;
342
369
}
343
- auto * func = module ->getFunction (name);
344
370
auto numParams = func->getNumParams ();
345
371
if (numParams == 0 ) {
346
372
continue ;
347
373
}
374
+ auto & calls = allCalls[index];
375
+ if (calls.empty ()) {
376
+ continue ;
377
+ }
378
+ auto name = func->name ;
348
379
auto [removedIndexes, outcome] = ParamUtils::removeParameters (
349
380
{func}, infoMap[name].unusedParams , calls, {}, module , getPassRunner ());
350
381
if (!removedIndexes.empty ()) {
351
382
// Success!
352
383
worthOptimizing.insert (func);
353
- markStale (func-> name );
384
+ markStale (name);
354
385
markCallersStale (calls);
355
386
}
356
387
if (outcome == ParamUtils::RemovalOutcome::Failure) {
@@ -362,25 +393,28 @@ struct DAE : public Pass {
362
393
// modified allCalls (we can't modify a call site twice in one iteration,
363
394
// once to remove a param, once to drop the return value).
364
395
if (worthOptimizing.empty ()) {
365
- for (auto & func : module ->functions ) {
396
+ for (Index index = 0 ; index < numFunctions; index++) {
397
+ auto & func = module ->functions [index];
398
+ if (func->imported ()) {
399
+ continue ;
400
+ }
366
401
if (func->getResults () == Type::none) {
367
402
continue ;
368
403
}
369
- auto name = func->name ;
370
- if (hasUnseenCalls.count (name)) {
404
+ if (hasUnseenCalls[index]) {
371
405
continue ;
372
406
}
407
+ auto name = func->name ;
373
408
if (infoMap[name].hasTailCalls ) {
374
409
continue ;
375
410
}
376
- if (tailCallees. count (name) ) {
411
+ if (tailCallees[index] ) {
377
412
continue ;
378
413
}
379
- auto iter = allCalls. find (name) ;
380
- if (iter == allCalls. end ()) {
414
+ auto & calls = allCalls[index] ;
415
+ if (calls. empty ()) {
381
416
continue ;
382
417
}
383
- auto & calls = iter->second ;
384
418
bool allDropped =
385
419
std::all_of (calls.begin (), calls.end (), [&](Call* call) {
386
420
return allDroppedCalls.count (call);
@@ -397,7 +431,7 @@ struct DAE : public Pass {
397
431
// TODO Removing a drop may also open optimization opportunities in the
398
432
// callers.
399
433
worthOptimizing.insert (func.get ());
400
- markStale (func-> name );
434
+ markStale (name);
401
435
markCallersStale (calls);
402
436
}
403
437
}
0 commit comments