Skip to content

Commit d19ad9a

Browse files
committed
fix: correctly use nested references
1 parent ea2523b commit d19ad9a

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

src/enzyme_ad/jax/Passes/MarkFunctionMemoryEffectsPass.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -257,35 +257,39 @@ struct MarkFunctionMemoryEffectsPass
257257
}
258258
}
259259

260-
void collectAllFunctions(
261-
Operation *op,
262-
DenseMap<SymbolRefAttr, FunctionOpInterface> &symbolToFunc) {
263-
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
264-
// Create the symbol reference for this function
265-
auto symbolRef = SymbolRefAttr::get(funcOp.getOperation());
266-
symbolToFunc[symbolRef] = funcOp;
267-
}
268-
for (Region &region : op->getRegions()) {
269-
for (Block &block : region) {
270-
for (Operation &childOp : block) {
271-
collectAllFunctions(&childOp, symbolToFunc);
272-
}
260+
SymbolRefAttr getFullReference(FunctionOpInterface funcOp) {
261+
SmallVector<StringRef> symbolPath;
262+
auto ctx = funcOp.getOperation()->getContext();
263+
auto op = funcOp.getOperation()->getParentOp();
264+
while (op) {
265+
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
266+
symbolPath.push_back(symbolOp.getName());
273267
}
268+
op = op->getParentOp();
269+
}
270+
if (symbolPath.empty()) {
271+
return SymbolRefAttr::get(funcOp.getOperation());
272+
}
273+
SmallVector<FlatSymbolRefAttr> nestedRefs;
274+
for (int i = 1; i < symbolPath.size(); i++) {
275+
nestedRefs.push_back(FlatSymbolRefAttr::get(ctx, symbolPath[i]));
274276
}
277+
nestedRefs.push_back(FlatSymbolRefAttr::get(ctx, funcOp.getNameAttr()));
278+
return SymbolRefAttr::get(ctx, symbolPath[0], nestedRefs);
275279
}
276280

277281
void runOnOperation() override {
278282
ModuleOp module = getOperation();
279283
auto *ctx = module->getContext();
280284
OpBuilder builder(ctx);
281285

286+
SymbolTableCollection symbolTable;
287+
symbolTable.getSymbolTable(module);
288+
282289
DenseMap<SymbolRefAttr, BitVector> funcEffects;
283290
DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
284291
DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;
285292

286-
// Collect all functions from the module and nested modules
287-
collectAllFunctions(module, symbolToFunc);
288-
289293
CallGraph callGraph(module);
290294

291295
bool hasCycle = false;
@@ -353,9 +357,10 @@ struct MarkFunctionMemoryEffectsPass
353357
return WalkResult::advance();
354358
});
355359

356-
auto symRef = SymbolRefAttr::get(funcOp.getOperation());
360+
auto symRef = getFullReference(funcOp);
357361
funcEffects[symRef] = std::move(effects);
358362
funcArgEffects[symRef] = std::move(argEffects);
363+
symbolToFunc[symRef] = funcOp;
359364
}
360365

361366
auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
@@ -396,7 +401,7 @@ struct MarkFunctionMemoryEffectsPass
396401
if (!funcOp)
397402
continue;
398403

399-
auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
404+
auto symRef = getFullReference(funcOp);
400405
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
401406
funcArgEffects);
402407
auto &effects = funcEffects[symRef];
@@ -425,7 +430,7 @@ struct MarkFunctionMemoryEffectsPass
425430
if (!funcOp)
426431
continue;
427432

428-
auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
433+
auto symRef = getFullReference(funcOp);
429434
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
430435
funcArgEffects);
431436
auto &effects = funcEffects[symRef];
@@ -435,11 +440,7 @@ struct MarkFunctionMemoryEffectsPass
435440

436441
// Finally, attach attributes
437442
for (auto &[symbol, effectsSet] : funcEffects) {
438-
auto it = symbolToFunc.find(symbol);
439-
if (it == symbolToFunc.end())
440-
continue;
441-
auto &funcOp = it->second;
442-
443+
auto funcOp = symbolToFunc[symbol];
443444
auto funcEffectInfo = getEffectInfo(builder, effectsSet);
444445
funcOp->setAttr("enzymexla.memory_effects",
445446
funcEffectInfo.enzymexlaEffects);

test/lit_tests/memoryeffects/nestedmodule.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ module {
1111
return
1212
}
1313
}
14+
module @nested2 {
15+
func.func public @single_dim(%output: memref<3xi64>, %values: memref<3xi64>) {
16+
// CHECK: func.func public @single_dim(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
17+
affine.parallel (%i) = (0) to (3) {
18+
%val = memref.load %values[%i] : memref<3xi64>
19+
affine.store %val, %output[%i] : memref<3xi64>
20+
}
21+
return
22+
}
23+
}
1424
func.func @main(%output: memref<3xi64>, %values: memref<3xi64>) {
1525
// CHECK: func.func @main(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
1626
affine.parallel (%i) = (0) to (3) {

0 commit comments

Comments
 (0)