Skip to content

Commit 8f3c5fd

Browse files
committed
fix: correctly use nested references
1 parent ea2523b commit 8f3c5fd

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

src/enzyme_ad/jax/Passes/MarkFunctionMemoryEffectsPass.cpp

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -257,21 +257,25 @@ struct MarkFunctionMemoryEffectsPass
257257
}
258258
}
259259

260-
void collectAllFunctions(
261-
Operation *op,
262-
DenseMap<SymbolRefAttr, FunctionOpInterface> &symbolToFunc) {
263-
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
264-
// Create the symbol reference for this function
265-
auto symbolRef = SymbolRefAttr::get(funcOp.getOperation());
266-
symbolToFunc[symbolRef] = funcOp;
267-
}
268-
for (Region &region : op->getRegions()) {
269-
for (Block &block : region) {
270-
for (Operation &childOp : block) {
271-
collectAllFunctions(&childOp, symbolToFunc);
272-
}
260+
SymbolRefAttr getFullReference(FunctionOpInterface funcOp) {
261+
SmallVector<StringRef> symbolPath;
262+
auto ctx = funcOp.getOperation()->getContext();
263+
auto op = funcOp.getOperation()->getParentOp();
264+
while (op) {
265+
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
266+
symbolPath.push_back(symbolOp.getName());
273267
}
268+
op = op->getParentOp();
269+
}
270+
if (symbolPath.empty()) {
271+
return SymbolRefAttr::get(funcOp.getOperation());
272+
}
273+
SmallVector<FlatSymbolRefAttr> nestedRefs;
274+
for (int i = 1; i < symbolPath.size(); i++) {
275+
nestedRefs.push_back(FlatSymbolRefAttr::get(ctx, symbolPath[i]));
274276
}
277+
nestedRefs.push_back(FlatSymbolRefAttr::get(ctx, funcOp.getNameAttr()));
278+
return SymbolRefAttr::get(ctx, symbolPath[0], nestedRefs);
275279
}
276280

277281
void runOnOperation() override {
@@ -283,9 +287,6 @@ struct MarkFunctionMemoryEffectsPass
283287
DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
284288
DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;
285289

286-
// Collect all functions from the module and nested modules
287-
collectAllFunctions(module, symbolToFunc);
288-
289290
CallGraph callGraph(module);
290291

291292
bool hasCycle = false;
@@ -353,9 +354,10 @@ struct MarkFunctionMemoryEffectsPass
353354
return WalkResult::advance();
354355
});
355356

356-
auto symRef = SymbolRefAttr::get(funcOp.getOperation());
357+
auto symRef = getFullReference(funcOp);
357358
funcEffects[symRef] = std::move(effects);
358359
funcArgEffects[symRef] = std::move(argEffects);
360+
symbolToFunc[symRef] = funcOp;
359361
}
360362

361363
auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
@@ -396,7 +398,7 @@ struct MarkFunctionMemoryEffectsPass
396398
if (!funcOp)
397399
continue;
398400

399-
auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
401+
auto symRef = getFullReference(funcOp);
400402
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
401403
funcArgEffects);
402404
auto &effects = funcEffects[symRef];
@@ -425,7 +427,7 @@ struct MarkFunctionMemoryEffectsPass
425427
if (!funcOp)
426428
continue;
427429

428-
auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
430+
auto symRef = getFullReference(funcOp);
429431
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
430432
funcArgEffects);
431433
auto &effects = funcEffects[symRef];
@@ -435,11 +437,7 @@ struct MarkFunctionMemoryEffectsPass
435437

436438
// Finally, attach attributes
437439
for (auto &[symbol, effectsSet] : funcEffects) {
438-
auto it = symbolToFunc.find(symbol);
439-
if (it == symbolToFunc.end())
440-
continue;
441-
auto &funcOp = it->second;
442-
440+
auto funcOp = symbolToFunc[symbol];
443441
auto funcEffectInfo = getEffectInfo(builder, effectsSet);
444442
funcOp->setAttr("enzymexla.memory_effects",
445443
funcEffectInfo.enzymexlaEffects);

test/lit_tests/memoryeffects/nestedmodule.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ module {
1111
return
1212
}
1313
}
14+
module @nested2 {
15+
func.func public @single_dim(%output: memref<3xi64>, %values: memref<3xi64>) {
16+
// CHECK: func.func public @single_dim(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
17+
affine.parallel (%i) = (0) to (3) {
18+
%val = memref.load %values[%i] : memref<3xi64>
19+
affine.store %val, %output[%i] : memref<3xi64>
20+
}
21+
return
22+
}
23+
}
1424
func.func @main(%output: memref<3xi64>, %values: memref<3xi64>) {
1525
// CHECK: func.func @main(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
1626
affine.parallel (%i) = (0) to (3) {

0 commit comments

Comments
 (0)