@@ -257,21 +257,25 @@ struct MarkFunctionMemoryEffectsPass
257257 }
258258 }
259259
260- void collectAllFunctions (
261- Operation *op,
262- DenseMap<SymbolRefAttr, FunctionOpInterface> &symbolToFunc) {
263- if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
264- // Create the symbol reference for this function
265- auto symbolRef = SymbolRefAttr::get (funcOp.getOperation ());
266- symbolToFunc[symbolRef] = funcOp;
267- }
268- for (Region ®ion : op->getRegions ()) {
269- for (Block &block : region) {
270- for (Operation &childOp : block) {
271- collectAllFunctions (&childOp, symbolToFunc);
272- }
260+ SymbolRefAttr getFullReference (FunctionOpInterface funcOp) {
261+ SmallVector<StringRef> symbolPath;
262+ auto ctx = funcOp.getOperation ()->getContext ();
263+ auto op = funcOp.getOperation ()->getParentOp ();
264+ while (op) {
265+ if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
266+ symbolPath.push_back (symbolOp.getName ());
273267 }
268+ op = op->getParentOp ();
269+ }
270+ if (symbolPath.empty ()) {
271+ return SymbolRefAttr::get (funcOp.getOperation ());
272+ }
273+ SmallVector<FlatSymbolRefAttr> nestedRefs;
274+ for (int i = 1 ; i < symbolPath.size (); i++) {
275+ nestedRefs.push_back (FlatSymbolRefAttr::get (ctx, symbolPath[i]));
274276 }
277+ nestedRefs.push_back (FlatSymbolRefAttr::get (ctx, funcOp.getNameAttr ()));
278+ return SymbolRefAttr::get (ctx, symbolPath[0 ], nestedRefs);
275279 }
276280
277281 void runOnOperation () override {
@@ -283,9 +287,6 @@ struct MarkFunctionMemoryEffectsPass
283287 DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
284288 DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;
285289
286- // Collect all functions from the module and nested modules
287- collectAllFunctions (module , symbolToFunc);
288-
289290 CallGraph callGraph (module );
290291
291292 bool hasCycle = false ;
@@ -353,9 +354,10 @@ struct MarkFunctionMemoryEffectsPass
353354 return WalkResult::advance ();
354355 });
355356
356- auto symRef = SymbolRefAttr::get (funcOp. getOperation () );
357+ auto symRef = getFullReference (funcOp);
357358 funcEffects[symRef] = std::move (effects);
358359 funcArgEffects[symRef] = std::move (argEffects);
360+ symbolToFunc[symRef] = funcOp;
359361 }
360362
361363 auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
@@ -396,7 +398,7 @@ struct MarkFunctionMemoryEffectsPass
396398 if (!funcOp)
397399 continue ;
398400
399- auto symRef = SymbolRefAttr::get (ctx, funcOp. getName () );
401+ auto symRef = getFullReference ( funcOp);
400402 analyzeFunctionArgumentMemoryEffects (funcOp, funcArgEffects[symRef],
401403 funcArgEffects);
402404 auto &effects = funcEffects[symRef];
@@ -425,7 +427,7 @@ struct MarkFunctionMemoryEffectsPass
425427 if (!funcOp)
426428 continue ;
427429
428- auto symRef = SymbolRefAttr::get (ctx, funcOp. getName () );
430+ auto symRef = getFullReference ( funcOp);
429431 analyzeFunctionArgumentMemoryEffects (funcOp, funcArgEffects[symRef],
430432 funcArgEffects);
431433 auto &effects = funcEffects[symRef];
@@ -435,11 +437,7 @@ struct MarkFunctionMemoryEffectsPass
435437
436438 // Finally, attach attributes
437439 for (auto &[symbol, effectsSet] : funcEffects) {
438- auto it = symbolToFunc.find (symbol);
439- if (it == symbolToFunc.end ())
440- continue ;
441- auto &funcOp = it->second ;
442-
440+ auto funcOp = symbolToFunc[symbol];
443441 auto funcEffectInfo = getEffectInfo (builder, effectsSet);
444442 funcOp->setAttr (" enzymexla.memory_effects" ,
445443 funcEffectInfo.enzymexlaEffects );
0 commit comments