Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions src/enzyme_ad/jax/Passes/MarkFunctionMemoryEffectsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ struct MarkFunctionMemoryEffectsPass
auto &memEffects = memEffectsOrNothing.value();

for (const auto &effect : memEffects) {
if (effect.getValue() && effect.getValue() == operand->get()) {
if (!effect.getValue() ||
(effect.getValue() && effect.getValue() == operand->get())) {
if (isa<MemoryEffects::Read>(effect.getEffect())) {
effects.set(0);
} else if (isa<MemoryEffects::Write>(effect.getEffect())) {
Expand Down Expand Up @@ -256,13 +257,35 @@ struct MarkFunctionMemoryEffectsPass
}
}

// Builds a fully qualified symbol reference for `funcOp`.
//
// The reference has one component per enclosing op that casts to
// SymbolOpInterface, ordered from the outermost such op down to the closest
// parent, followed by the function's own name as the leaf. Parents that are
// not symbol ops are skipped. When no enclosing symbol op exists, the result
// is the plain reference produced by SymbolRefAttr::get for the function
// operation.
//
// \param funcOp function whose qualified reference is requested.
// \returns the SymbolRefAttr described above.
SymbolRefAttr getFullReference(FunctionOpInterface funcOp) {
  auto *funcOperation = funcOp.getOperation();
  auto ctx = funcOperation->getContext();

  // Walk up the parent chain and record each enclosing symbol name. Because
  // the walk starts at the closest parent, the list is ordered
  // innermost -> outermost.
  SmallVector<StringRef> symbolPath;
  for (auto *op = funcOperation->getParentOp(); op; op = op->getParentOp()) {
    if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
      symbolPath.push_back(symbolOp.getName());
    }
  }
  if (symbolPath.empty()) {
    return SymbolRefAttr::get(funcOperation);
  }

  // A nested symbol reference is rooted at the outermost symbol and descends
  // towards the leaf, so the collected parents are emitted in reverse:
  // the last collected name is the root, the remaining ones follow from the
  // outside in, and the function name comes last.
  SmallVector<FlatSymbolRefAttr> nestedRefs;
  for (auto it = symbolPath.rbegin() + 1, end = symbolPath.rend(); it != end;
       ++it) {
    nestedRefs.push_back(FlatSymbolRefAttr::get(ctx, *it));
  }
  nestedRefs.push_back(FlatSymbolRefAttr::get(ctx, funcOp.getNameAttr()));
  return SymbolRefAttr::get(ctx, symbolPath.back(), nestedRefs);
}

void runOnOperation() override {
ModuleOp module = getOperation();
auto *ctx = module->getContext();
OpBuilder builder(ctx);

DenseMap<SymbolRefAttr, BitVector> funcEffects;
DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
DenseMap<SymbolRefAttr, FunctionOpInterface> symbolToFunc;

CallGraph callGraph(module);

Expand Down Expand Up @@ -331,9 +354,10 @@ struct MarkFunctionMemoryEffectsPass
return WalkResult::advance();
});

auto symRef = SymbolRefAttr::get(funcOp.getOperation());
auto symRef = getFullReference(funcOp);
funcEffects[symRef] = std::move(effects);
funcArgEffects[symRef] = std::move(argEffects);
symbolToFunc[symRef] = funcOp;
}

auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
Expand Down Expand Up @@ -374,7 +398,7 @@ struct MarkFunctionMemoryEffectsPass
if (!funcOp)
continue;

auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
auto symRef = getFullReference(funcOp);
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
funcArgEffects);
auto &effects = funcEffects[symRef];
Expand Down Expand Up @@ -403,7 +427,7 @@ struct MarkFunctionMemoryEffectsPass
if (!funcOp)
continue;

auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
auto symRef = getFullReference(funcOp);
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
funcArgEffects);
auto &effects = funcEffects[symRef];
Expand All @@ -413,11 +437,7 @@ struct MarkFunctionMemoryEffectsPass

// Finally, attach attributes
for (auto &[symbol, effectsSet] : funcEffects) {
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(
module.lookupSymbol(symbol.getLeafReference()));
if (!funcOp)
continue;

auto funcOp = symbolToFunc[symbol];
auto funcEffectInfo = getEffectInfo(builder, effectsSet);
funcOp->setAttr("enzymexla.memory_effects",
funcEffectInfo.enzymexlaEffects);
Expand All @@ -429,18 +449,20 @@ struct MarkFunctionMemoryEffectsPass
argEffectInfo.enzymexlaEffects);

if (isPointerType(funcOp.getArgument(i))) {
if (argEffectInfo.readOnly) {
if (argEffectInfo.readOnly && !argEffectInfo.readNone) {
assert(!argEffectInfo.writeOnly && "readOnly and writeOnly?");
funcOp.setArgAttr(i, LLVM::LLVMDialect::getReadonlyAttrName(),
builder.getUnitAttr());
}
if (argEffectInfo.writeOnly) {
if (argEffectInfo.writeOnly && !argEffectInfo.readNone) {
assert(!argEffectInfo.readOnly && "writeOnly and readOnly?");
funcOp.setArgAttr(i, LLVM::LLVMDialect::getWriteOnlyAttrName(),
builder.getUnitAttr());
}
// if (argEffectInfo.readNone) {
// funcOp.setArgAttr(i, LLVM::LLVMDialect::getReadnoneAttrName(),
// builder.getUnitAttr());
// }
if (argEffectInfo.readNone) {
funcOp.setArgAttr(i, LLVM::LLVMDialect::getReadnoneAttrName(),
builder.getUnitAttr());
}
if (!argEffects[i][3]) {
funcOp.setArgAttr(i, LLVM::LLVMDialect::getNoFreeAttrName(),
builder.getUnitAttr());
Expand Down
85 changes: 85 additions & 0 deletions test/lit_tests/memoryeffects/add_kernel.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// RUN: enzymexlamlir-opt --mark-func-memory-effects %s | FileCheck %s

// Test case: a kernel expressed in the LLVM/NVVM dialects. Every memory
// access to %arg0, %arg1 and %arg2 goes through llvm.inline_asm marked
// has_side_effects, and the expectation below is that each of those pointers
// is annotated with both "read" and "write". %arg3 is never used in the body
// and is expected to receive an empty effect list plus llvm.readnone.
module {
// CHECK: llvm.func @add_kernel(%arg0: !llvm.ptr<1> {enzymexla.memory_effects = ["read", "write"], llvm.nofree}, %arg1: !llvm.ptr<1> {enzymexla.memory_effects = ["read", "write"], llvm.nofree}, %arg2: !llvm.ptr<1> {enzymexla.memory_effects = ["read", "write"], llvm.nofree}, %arg3: !llvm.ptr<1> {enzymexla.memory_effects = [], llvm.nofree, llvm.readnone}) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"], noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 32>, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
llvm.func @add_kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 32>, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
// Constants and index arithmetic: two element offsets (%22, %23) and their
// bounds predicates (%24, %25) are derived from the block and thread ids.
%0 = llvm.mlir.undef : vector<1xf32>
%1 = llvm.mlir.constant(0 : i32) : i32
%2 = llvm.mlir.constant(32 : i32) : i32
%3 = llvm.mlir.constant(31 : i32) : i32
%4 = llvm.mlir.constant(0 : index) : i32
%5 = llvm.mlir.constant(1024 : i32) : i32
%6 = llvm.mlir.constant(64 : i32) : i32
%7 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
%8 = llvm.mul %7, %6 : i32
%9 = nvvm.read.ptx.sreg.tid.x : i32
%10 = llvm.and %9, %3 : i32
%11 = llvm.shl %10, %1 : i32
%12 = llvm.or %1, %11 : i32
%13 = llvm.or %12, %1 : i32
%14 = llvm.and %13, %3 : i32
%15 = llvm.lshr %14, %1 : i32
%16 = llvm.xor %1, %15 : i32
%17 = llvm.xor %1, %16 : i32
%18 = llvm.xor %17, %1 : i32
%19 = llvm.xor %17, %2 : i32
%20 = llvm.add %18, %4 : i32
%21 = llvm.add %19, %4 : i32
%22 = llvm.add %8, %20 : i32
%23 = llvm.add %8, %21 : i32
%24 = llvm.icmp "slt" %22, %5 : i32
%25 = llvm.icmp "slt" %23, %5 : i32
// Pointers derived from %arg0 are passed to inline asm (ld.global).
%26 = llvm.getelementptr %arg0[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%27 = llvm.getelementptr %arg0[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%28 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %26, %24 : (!llvm.ptr<1>, i1) -> i32
%29 = llvm.bitcast %28 : i32 to vector<1xf32>
%30 = llvm.extractelement %29[%4 : i32] : vector<1xf32>
%31 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %27, %25 : (!llvm.ptr<1>, i1) -> i32
%32 = llvm.bitcast %31 : i32 to vector<1xf32>
%33 = llvm.extractelement %32[%4 : i32] : vector<1xf32>
// Pointers derived from %arg1 are passed to inline asm (ld.global).
%34 = llvm.getelementptr %arg1[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%35 = llvm.getelementptr %arg1[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%36 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %34, %24 : (!llvm.ptr<1>, i1) -> i32
%37 = llvm.bitcast %36 : i32 to vector<1xf32>
%38 = llvm.extractelement %37[%4 : i32] : vector<1xf32>
%39 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %35, %25 : (!llvm.ptr<1>, i1) -> i32
%40 = llvm.bitcast %39 : i32 to vector<1xf32>
%41 = llvm.extractelement %40[%4 : i32] : vector<1xf32>
// Element-wise sums of the values obtained via %arg0 and %arg1.
%42 = llvm.fadd %30, %38 : f32
%43 = llvm.fadd %33, %41 : f32
// Pointers derived from %arg2 are passed to inline asm (st.global).
%44 = llvm.getelementptr %arg2[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%45 = llvm.getelementptr %arg2[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%46 = llvm.insertelement %42, %0[%1 : i32] : vector<1xf32>
%47 = llvm.bitcast %46 : vector<1xf32> to i32
%48 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %47, %44, %24 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
%49 = llvm.insertelement %43, %0[%1 : i32] : vector<1xf32>
%50 = llvm.bitcast %49 : vector<1xf32> to i32
%51 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %50, %45, %25 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
llvm.return
}
}

// Test case: the same element-wise add written with Triton (tt) ops.
// %arg0 and %arg1 only feed tt.load and %arg2 only feeds tt.store, and the
// expectation below is readonly / readonly / writeonly on the respective
// pointer arguments, with "read" and "write" on the function itself.
module {
// CHECK: tt.func @add_kernel_call(%arg0: !tt.ptr<f32> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}, %arg1: !tt.ptr<f32> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}, %arg2: !tt.ptr<f32> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}) attributes {enzymexla.memory_effects = ["read", "write"], noinline = false}
tt.func @add_kernel_call(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
// Per-program offsets (%4) and the in-bounds mask (%5).
%cst = arith.constant dense<1024> : tensor<64xi32>
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%3 = tt.splat %1 : i32 -> tensor<64xi32>
%4 = arith.addi %3, %2 : tensor<64xi32>
%5 = arith.cmpi slt, %4, %cst : tensor<64xi32>
// Masked load through pointers derived from %arg0.
%6 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%7 = tt.addptr %6, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%8 = tt.load %7, %5 : tensor<64x!tt.ptr<f32>>
// Masked load through pointers derived from %arg1.
%9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%10 = tt.addptr %9, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%11 = tt.load %10, %5 : tensor<64x!tt.ptr<f32>>
%12 = arith.addf %8, %11 : tensor<64xf32>
// Masked store of the sum through pointers derived from %arg2.
%13 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%14 = tt.addptr %13, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %14, %12, %5 : tensor<64x!tt.ptr<f32>>
tt.return
}
}
32 changes: 32 additions & 0 deletions test/lit_tests/memoryeffects/nestedmodule.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: enzymexlamlir-opt --mark-func-memory-effects %s | FileCheck %s

// Test case: an unnamed outer module containing two named nested modules
// (@nested and @nested2) that each define a function called @single_dim,
// plus a top-level @main. All three functions have the same body and the
// same expected annotations.
// NOTE(review): presumably this exercises functions that share a leaf name
// across different symbol tables - confirm against the pass implementation.
module {
module @nested {
func.func public @single_dim(%output: memref<3xi64>, %values: memref<3xi64>) {
// CHECK: func.func public @single_dim(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
// The body loads from %values and stores to %output.
affine.parallel (%i) = (0) to (3) {
%val = memref.load %values[%i] : memref<3xi64>
affine.store %val, %output[%i] : memref<3xi64>
}
return
}
}
module @nested2 {
func.func public @single_dim(%output: memref<3xi64>, %values: memref<3xi64>) {
// CHECK: func.func public @single_dim(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
// Same body as @nested::@single_dim.
affine.parallel (%i) = (0) to (3) {
%val = memref.load %values[%i] : memref<3xi64>
affine.store %val, %output[%i] : memref<3xi64>
}
return
}
}
func.func @main(%output: memref<3xi64>, %values: memref<3xi64>) {
// CHECK: func.func @main(%arg0: memref<3xi64> {enzymexla.memory_effects = ["write"], llvm.nofree, llvm.writeonly}, %arg1: memref<3xi64> {enzymexla.memory_effects = ["read"], llvm.nofree, llvm.readonly}) attributes {enzymexla.memory_effects = ["read", "write"]}
// Top-level function (directly inside the unnamed outer module) with the
// same load/store body.
affine.parallel (%i) = (0) to (3) {
%val = memref.load %values[%i] : memref<3xi64>
affine.store %val, %output[%i] : memref<3xi64>
}
return
}
}
9 changes: 2 additions & 7 deletions test/lit_tests/memoryeffects/shlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ module {
return %0 : tensor<64xi64>
}

// ASSUME: @main4(%arg0: tensor<64xi64> {enzymexla.memory_effects = []}) -> tensor<64xi64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
// NOASSUME: func.func @main4(%arg0: tensor<64xi64> {enzymexla.memory_effects = []}) -> tensor<64xi64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
// ASSUME: @main4(%arg0: tensor<64xi64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> tensor<64xi64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
// NOASSUME: func.func @main4(%arg0: tensor<64xi64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> tensor<64xi64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {

func.func @main5(%arg0: tensor<64xi64>) -> tensor<64xi64> {
%0 = stablehlo.custom_call @mycall1(%arg0) {has_side_effect = false} : (tensor<64xi64>) -> tensor<64xi64>
Expand All @@ -44,9 +44,4 @@ module {

// ASSUME: @main5(%arg0: tensor<64xi64> {enzymexla.memory_effects = []}) -> tensor<64xi64> attributes {enzymexla.memory_effects = []} {
// NOASSUME: func.func @main5(%arg0: tensor<64xi64> {enzymexla.memory_effects = []}) -> tensor<64xi64> attributes {enzymexla.memory_effects = []} {


}



Loading