Skip to content

Commit 05814a8

Browse files
committed
fix: mark memory effects for nested modules correctly
1 parent 9b689ed commit 05814a8

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

src/enzyme_ad/jax/Passes/MarkFunctionMemoryEffectsPass.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,34 @@ struct MarkFunctionMemoryEffectsPass
256256
}
257257
}
258258

259+
void collectAllFunctions(
260+
Operation *op,
261+
DenseMap<SymbolRefAttr, FunctionOpInterface> &symbolToFunc) {
262+
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
263+
// Create the symbol reference for this function
264+
auto symbolRef = SymbolRefAttr::get(funcOp.getOperation());
265+
symbolToFunc[symbolRef] = funcOp;
266+
}
267+
for (Region &region : op->getRegions()) {
268+
for (Block &block : region) {
269+
for (Operation &childOp : block) {
270+
collectAllFunctions(&childOp, symbolToFunc);
271+
}
272+
}
273+
}
274+
}
275+
259276
void runOnOperation() override {
260277
ModuleOp module = getOperation();
261278
auto *ctx = module->getContext();
262279
OpBuilder builder(ctx);
263280

264281
DenseMap<SymbolRefAttr, BitVector> funcEffects;
265282
DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
283+
DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;
284+
285+
// Collect all functions from the module and nested modules
286+
collectAllFunctions(module, symbolToFunc);
266287

267288
CallGraph callGraph(module);
268289

@@ -413,10 +434,10 @@ struct MarkFunctionMemoryEffectsPass
413434

414435
// Finally, attach attributes
415436
for (auto &[symbol, effectsSet] : funcEffects) {
416-
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(
417-
module.lookupSymbol(symbol.getLeafReference()));
418-
if (!funcOp)
437+
auto it = symbolToFunc.find(symbol);
438+
if (it == symbolToFunc.end())
419439
continue;
440+
auto &funcOp = it->second;
420441

421442
auto funcEffectInfo = getEffectInfo(builder, effectsSet);
422443
funcOp->setAttr("enzymexla.memory_effects",
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: enzymexlamlir-opt --mark-func-memory-effects %s | FileCheck %s
2+
3+
module {
4+
module @nested {
5+
func.func public @single_dim(%output: memref<3xi64>, %values: memref<3xi64>) {
6+
// CHECK: func.func public @single_dim(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
7+
affine.parallel (%i) = (0) to (3) {
8+
%val = memref.load %values[%i] : memref<3xi64>
9+
affine.store %val, %output[%i] : memref<3xi64>
10+
}
11+
return
12+
}
13+
}
14+
func.func @main(%output: memref<3xi64>, %values: memref<3xi64>) {
15+
// CHECK: func.func @main(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
16+
affine.parallel (%i) = (0) to (3) {
17+
%val = memref.load %values[%i] : memref<3xi64>
18+
affine.store %val, %output[%i] : memref<3xi64>
19+
}
20+
return
21+
}
22+
}

0 commit comments

Comments
 (0)