@@ -71,6 +71,14 @@ LogicalResult
71
71
mlir::bufferization::dropEquivalentBufferResults (ModuleOp module ) {
72
72
IRRewriter rewriter (module .getContext ());
73
73
74
+ DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
75
+ // Collect the mapping of functions to their call sites.
76
+ module .walk ([&](func::CallOp callOp) {
77
+ if (func::FuncOp calledFunc = getCalledFunction (callOp)) {
78
+ callerMap[calledFunc].insert (callOp);
79
+ }
80
+ });
81
+
74
82
for (auto funcOp : module .getOps <func::FuncOp>()) {
75
83
if (funcOp.isExternal ())
76
84
continue ;
@@ -109,10 +117,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
109
117
returnOp.getOperandsMutable ().assign (newReturnValues);
110
118
111
119
// Update function calls.
112
- module .walk ([&](func::CallOp callOp) {
113
- if (getCalledFunction (callOp) != funcOp)
114
- return WalkResult::skip ();
115
-
120
+ for (func::CallOp callOp : callerMap[funcOp]) {
116
121
rewriter.setInsertionPoint (callOp);
117
122
auto newCallOp = rewriter.create <func::CallOp>(callOp.getLoc (), funcOp,
118
123
callOp.getOperands ());
@@ -136,8 +141,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
136
141
newResults.push_back (replacement);
137
142
}
138
143
rewriter.replaceOp (callOp, newResults);
139
- return WalkResult::advance ();
140
- });
144
+ }
141
145
}
142
146
143
147
return success ();
0 commit comments