@@ -256,13 +256,34 @@ struct MarkFunctionMemoryEffectsPass
256256 }
257257 }
258258
259+ void collectAllFunctions (
260+ Operation *op,
261+ DenseMap<SymbolRefAttr, FunctionOpInterface> &symbolToFunc) {
262+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
263+ // Create the symbol reference for this function
264+ auto symbolRef = SymbolRefAttr::get (funcOp.getOperation ());
265+ symbolToFunc[symbolRef] = funcOp;
266+ }
267+ for (Region ®ion : op->getRegions ()) {
268+ for (Block &block : region) {
269+ for (Operation &childOp : block) {
270+ collectAllFunctions (&childOp, symbolToFunc);
271+ }
272+ }
273+ }
274+ }
275+
259276 void runOnOperation () override {
260277 ModuleOp module = getOperation ();
261278 auto *ctx = module ->getContext ();
262279 OpBuilder builder (ctx);
263280
264281 DenseMap<SymbolRefAttr, BitVector> funcEffects;
265282 DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
283+ DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;
284+
285+ // Collect all functions from the module and nested modules
286+ collectAllFunctions (module , symbolToFunc);
266287
267288 CallGraph callGraph (module );
268289
@@ -413,10 +434,10 @@ struct MarkFunctionMemoryEffectsPass
413434
414435 // Finally, attach attributes
415436 for (auto &[symbol, effectsSet] : funcEffects) {
416- auto funcOp = dyn_cast_or_null<FunctionOpInterface>(
417- module .lookupSymbol (symbol.getLeafReference ()));
418- if (!funcOp)
437+ auto it = symbolToFunc.find (symbol);
438+ if (it == symbolToFunc.end ())
419439 continue ;
440+ auto &funcOp = it->second ;
420441
421442 auto funcEffectInfo = getEffectInfo (builder, effectsSet);
422443 funcOp->setAttr (" enzymexla.memory_effects" ,
0 commit comments