2727// If so, we can avoid even sending and receiving it. (Note how if
2828// the previous point was true for an argument, then the second
2929// must as well.)
30+ // * Find return values ("return arguments" ;) that are never used.
3031//
3132// This pass does not depend on flattening, but it may be more effective,
3233// as then call arguments never have side effects (which we need to
@@ -53,6 +54,9 @@ struct DAEFunctionInfo {
5354 SortedVector unusedParams;
5455 // Maps a function name to the calls going to it.
5556 std::unordered_map<Name, std::vector<Call*>> calls;
57+ // Map of all calls that are dropped, to their drops' locations (so that
58+ // if we can optimize out the drop, we can replace the drop there).
59+ std::unordered_map<Call*, Expression**> droppedCalls;
5660 // Whether the function can be called from places that
5761 // affect what we can do. For now, any call we don't
5862 // see inhibits our optimizations, but TODO: an export
@@ -116,6 +120,12 @@ struct DAEScanner : public WalkerPass<CFGWalker<DAEScanner, Visitor<DAEScanner>,
116120 }
117121 }
118122
123+ void visitDrop (Drop* curr) {
124+ if (auto * call = curr->value ->dynCast <Call>()) {
125+ info->droppedCalls [call] = getCurrentPointer ();
126+ }
127+ }
128+
119129 // main entry point
120130
121131 void doWalkFunction (Function* func) {
@@ -197,6 +207,15 @@ struct DAE : public Pass {
197207 bool optimize = false ;
198208
199209 void run (PassRunner* runner, Module* module ) override {
210+ // Iterate to convergence.
211+ while (1 ) {
212+ if (!iteration (runner, module )) {
213+ break ;
214+ }
215+ }
216+ }
217+
218+ bool iteration (PassRunner* runner, Module* module ) {
200219 DAEFunctionInfoMap infoMap;
201220 // Ensure they all exist so the parallel threads don't modify the data structure.
202221 ModuleUtils::iterDefinedFunctions (*module , [&](Function* func) {
@@ -230,14 +249,19 @@ struct DAE : public Pass {
230249 auto & allCallsToName = allCalls[name];
231250 allCallsToName.insert (allCallsToName.end (), calls.begin (), calls.end ());
232251 }
252+ for (auto & pair : info.droppedCalls ) {
253+ allDroppedCalls[pair.first ] = pair.second ;
254+ }
233255 }
234256 // We now have a mapping of all call sites for each function. Check which
235257 // are always passed the same constant for a particular argument.
236258 for (auto & pair : allCalls) {
237259 auto name = pair.first ;
238260 // We can only optimize if we see all the calls and can modify
239261 // them.
240- if (infoMap[name].hasUnseenCalls ) continue ;
262+ if (infoMap[name].hasUnseenCalls ) {
263+ continue ;
264+ }
241265 auto & calls = pair.second ;
242266 auto * func = module ->getFunction (name);
243267 auto numParams = func->getNumParams ();
@@ -311,13 +335,48 @@ struct DAE : public Pass {
311335 i--;
312336 }
313337 }
314- if (optimize && changed.size () > 0 ) {
338+ // We can also tell which calls have all their return values dropped. Note that we can't do this
339+ // if we changed anything so far, as we may have modified allCalls (we can't modify a call site
340+ // twice in one iteration, once to remove a param, once to drop the return value).
341+ if (changed.empty ()) {
342+ for (auto & func : module ->functions ) {
343+ if (func->result == none) {
344+ continue ;
345+ }
346+ auto name = func->name ;
347+ if (infoMap[name].hasUnseenCalls ) {
348+ continue ;
349+ }
350+ auto iter = allCalls.find (name);
351+ if (iter == allCalls.end ()) {
352+ continue ;
353+ }
354+ auto & calls = iter->second ;
355+ bool allDropped = true ;
356+ for (auto * call : calls) {
357+ if (!allDroppedCalls.count (call)) {
358+ allDropped = false ;
359+ break ;
360+ }
361+ }
362+ if (!allDropped) {
363+ continue ;
364+ }
365+ removeReturnValue (func.get (), calls, module );
366+ // TODO Removing a drop may also open optimization opportunities in the callers.
367+ changed.insert (func.get ());
368+ }
369+ }
370+ if (optimize && !changed.empty ()) {
315371 OptUtils::optimizeAfterInlining (changed, module , runner);
316372 }
373+ return !changed.empty ();
317374 }
318375
319376private:
320- void removeParameter (Function* func, Index i, std::vector<Call*> calls) {
377+ std::unordered_map<Call*, Expression**> allDroppedCalls;
378+
379+ void removeParameter (Function* func, Index i, std::vector<Call*>& calls) {
321380 // Clear the type, which is no longer accurate.
322381 func->type = Name ();
323382 // It's cumbersome to adjust local names - TODO don't clear them?
@@ -354,6 +413,45 @@ struct DAE : public Pass {
354413 call->operands .erase (call->operands .begin () + i);
355414 }
356415 }
416+
417+ void removeReturnValue (Function* func, std::vector<Call*>& calls, Module* module ) {
418+ // Clear the type, which is no longer accurate.
419+ func->type = Name ();
420+ func->result = none;
421+ Builder builder (*module );
422+ // Remove any return values.
423+ struct ReturnUpdater : public PostWalker <ReturnUpdater> {
424+ Module* module ;
425+ ReturnUpdater (Function* func, Module* module ) : module (module ) {
426+ walk (func->body );
427+ }
428+ void visitReturn (Return* curr) {
429+ auto * value = curr->value ;
430+ assert (value);
431+ curr->value = nullptr ;
432+ Builder builder (*module );
433+ replaceCurrent (builder.makeSequence (
434+ builder.makeDrop (value),
435+ curr
436+ ));
437+ }
438+ } returnUpdater (func, module );
439+ // Remove any value flowing out.
440+ if (isConcreteType (func->body ->type )) {
441+ func->body = builder.makeDrop (func->body );
442+ }
443+ // Remove the drops on the calls.
444+ for (auto * call : calls) {
445+ auto iter = allDroppedCalls.find (call);
446+ assert (iter != allDroppedCalls.end ());
447+ Expression** location = iter->second ;
448+ *location = call;
449+ // Update the call's type.
450+ if (call->type != unreachable) {
451+ call->type = none;
452+ }
453+ }
454+ }
357455};
358456
359457Pass *createDAEPass () {
0 commit comments