@@ -257,35 +257,39 @@ struct MarkFunctionMemoryEffectsPass
257257 }
258258 }
259259
260- void collectAllFunctions (
261- Operation *op,
262- DenseMap<SymbolRefAttr, FunctionOpInterface> &symbolToFunc) {
263- if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
264- // Create the symbol reference for this function
265- auto symbolRef = SymbolRefAttr::get (funcOp.getOperation ());
266- symbolToFunc[symbolRef] = funcOp;
267- }
268- for (Region ®ion : op->getRegions ()) {
269- for (Block &block : region) {
270- for (Operation &childOp : block) {
271- collectAllFunctions (&childOp, symbolToFunc);
272- }
260+ SymbolRefAttr getFullReference (FunctionOpInterface funcOp) {
261+ SmallVector<StringRef> symbolPath;
262+ auto ctx = funcOp.getOperation ()->getContext ();
263+ auto op = funcOp.getOperation ()->getParentOp ();
264+ while (op) {
265+ if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
266+ symbolPath.push_back (symbolOp.getName ());
273267 }
268+ op = op->getParentOp ();
269+ }
270+ if (symbolPath.empty ()) {
271+ return SymbolRefAttr::get (funcOp.getOperation ());
272+ }
273+ SmallVector<FlatSymbolRefAttr> nestedRefs;
274+ for (int i = 1 ; i < symbolPath.size (); i++) {
275+ nestedRefs.push_back (FlatSymbolRefAttr::get (ctx, symbolPath[i]));
274276 }
277+ nestedRefs.push_back (FlatSymbolRefAttr::get (ctx, funcOp.getNameAttr ()));
278+ return SymbolRefAttr::get (ctx, symbolPath[0 ], nestedRefs);
275279 }
276280
277281 void runOnOperation () override {
278282 ModuleOp module = getOperation ();
279283 auto *ctx = module ->getContext ();
280284 OpBuilder builder (ctx);
281285
286+ SymbolTableCollection symbolTable;
287+ symbolTable.getSymbolTable (module );
288+
282289 DenseMap<SymbolRefAttr, BitVector> funcEffects;
283290 DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
284291 DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;
285292
286- // Collect all functions from the module and nested modules
287- collectAllFunctions (module , symbolToFunc);
288-
289293 CallGraph callGraph (module );
290294
291295 bool hasCycle = false ;
@@ -353,9 +357,10 @@ struct MarkFunctionMemoryEffectsPass
353357 return WalkResult::advance ();
354358 });
355359
356- auto symRef = SymbolRefAttr::get (funcOp. getOperation () );
360+ auto symRef = getFullReference (funcOp);
357361 funcEffects[symRef] = std::move (effects);
358362 funcArgEffects[symRef] = std::move (argEffects);
363+ symbolToFunc[symRef] = funcOp;
359364 }
360365
361366 auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
@@ -396,7 +401,7 @@ struct MarkFunctionMemoryEffectsPass
396401 if (!funcOp)
397402 continue ;
398403
399- auto symRef = SymbolRefAttr::get (ctx, funcOp. getName () );
404+ auto symRef = getFullReference ( funcOp);
400405 analyzeFunctionArgumentMemoryEffects (funcOp, funcArgEffects[symRef],
401406 funcArgEffects);
402407 auto &effects = funcEffects[symRef];
@@ -425,7 +430,7 @@ struct MarkFunctionMemoryEffectsPass
425430 if (!funcOp)
426431 continue ;
427432
428- auto symRef = SymbolRefAttr::get (ctx, funcOp. getName () );
433+ auto symRef = getFullReference ( funcOp);
429434 analyzeFunctionArgumentMemoryEffects (funcOp, funcArgEffects[symRef],
430435 funcArgEffects);
431436 auto &effects = funcEffects[symRef];
@@ -435,11 +440,7 @@ struct MarkFunctionMemoryEffectsPass
435440
436441 // Finally, attach attributes
437442 for (auto &[symbol, effectsSet] : funcEffects) {
438- auto it = symbolToFunc.find (symbol);
439- if (it == symbolToFunc.end ())
440- continue ;
441- auto &funcOp = it->second ;
442-
443+ auto funcOp = symbolToFunc[symbol];
443444 auto funcEffectInfo = getEffectInfo (builder, effectsSet);
444445 funcOp->setAttr (" enzymexla.memory_effects" ,
445446 funcEffectInfo.enzymexlaEffects );
0 commit comments