Skip to content

Commit 96b8efc

Browse files
jumerckx and wsmoses authored
Batched reverse mode (#2216)
* fix for reversemode
* fix test
* fixup
* fixup

---------

Co-authored-by: William S. Moses <[email protected]>
1 parent 5b330a9 commit 96b8efc

File tree

10 files changed

+53
-18
lines changed

10 files changed

+53
-18
lines changed

.github/workflows/enzyme-mlir.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
- uses: actions/checkout@v4
3737
with:
3838
repository: 'llvm/llvm-project'
39-
ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e'
39+
ref: 'ff24e9a19e3db330dd6412aac9d1d6c0b416697f'
4040
path: 'llvm-project'
4141

4242
- name: Get MLIR commit hash

enzyme/Enzyme/MLIR/Passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
1313
SimplifyMath.cpp
1414
AddToOpToIndexAndLoad.cpp
1515
AddToOpToSplit.cpp
16+
RemovalUtils.cpp
1617
RemoveUnusedEnzymeOps.cpp
1718
SimplifyMemrefCache.cpp
1819
Utils.cpp

enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
3939
pm.getDependentDialects(registry);
4040
}
4141

42-
registry
43-
.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
44-
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect>();
42+
registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
43+
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
44+
mlir::enzyme::EnzymeDialect>();
4545
}
4646

4747
static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,

enzyme/Enzyme/MLIR/Passes/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> {
1818
"complex::ComplexDialect",
1919
"cf::ControlFlowDialect",
2020
"tensor::TensorDialect",
21+
"enzyme::EnzymeDialect",
2122
];
2223
let options = [
2324
Option<

enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) {
2626
other.initOp->erase();
2727
}
2828

29-
enzyme::PushOp newPushOp = pushOp;
3029
other.pushOp->erase();
3130

3231
enzyme::PopOp newPopOp;

enzyme/Enzyme/MLIR/Passes/RemovalUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct CacheInfo {
4141

4242
Value pushedValue() { return pushOp.getValue(); }
4343
Type cachedType() {
44-
return initOp.getResult().getType().cast<enzyme::CacheType>().getType();
44+
return cast<enzyme::CacheType>(initOp.getResult().getType()).getType();
4545
}
4646

4747
// Pushed values must be the same

enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass
306306

307307
applyPatterns(op);
308308

309+
bool failed = false;
309310
op->walk([&](FunctionOpInterface func) {
310311
func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) {
311-
iface.removeEnzymeOps();
312+
auto result = iface.removeEnzymeOps();
313+
if (!result.succeeded())
314+
failed = true;
312315
});
313316
});
314317

318+
if (failed)
319+
return signalPassFailure();
320+
315321
applyPatterns(op);
316322
}
317323
};

enzyme/test/MLIR/ForwardMode/batched_scalar.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@ module {
2121
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
2222
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
2323
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
24-
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
24+
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]]
2525
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
2626
// CHECK-NEXT: }
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math %s | FileCheck %s
2+
3+
module {
4+
func.func @square(%x: f64) -> f64 {
5+
%next = arith.mulf %x, %x : f64
6+
return %next : f64
7+
}
8+
9+
func.func @dsquare(%x: f64, %dr: tensor<2xf64>) -> tensor<2xf64> {
10+
%r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>], width=2 } : (f64, tensor<2xf64>) -> tensor<2xf64>
11+
return %r : tensor<2xf64>
12+
}
13+
}
14+
15+
// CHECK: func.func @dsquare(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
16+
// CHECK-NEXT: %0 = call @diffe2square(%arg0, %arg1) : (f64, tensor<2xf64>) -> tensor<2xf64>
17+
// CHECK-NEXT: return %0 : tensor<2xf64>
18+
// CHECK-NEXT: }
19+
20+
// CHECK: func.func private @diffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
21+
// CHECK-NEXT: %0 = "enzyme.broadcast"(%arg0) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
22+
// CHECK-NEXT: %1 = arith.mulf %arg1, %0 : tensor<2xf64>
23+
// CHECK-NEXT: %2 = "enzyme.broadcast"(%arg0) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
24+
// CHECK-NEXT: %3 = arith.mulf %arg1, %2 : tensor<2xf64>
25+
// CHECK-NEXT: %4 = arith.addf %1, %3 : tensor<2xf64>
26+
// CHECK-NEXT: return %4 : tensor<2xf64>
27+
// CHECK-NEXT: }

enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,16 +277,17 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
277277
if (!vecValue && !startsWith(ord, "local")) {
278278
if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
279279
os << ")";
280-
if (intrinsic == MLIRDerivatives) {
281-
os << ";\n";
282-
os << "if (gutils->width != 1) {\n"
283-
<< " " << argName << "_" << (idx - 1)
284-
<< " = builder.create<enzyme::BroadcastOp>(\n"
285-
<< " op.getLoc(),\n"
286-
<< " " << argName << "_" << (idx - 1) << ",\n"
287-
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
288-
<< "}";
289-
}
280+
}
281+
if (intrinsic == MLIRDerivatives) {
282+
os << ";\n";
283+
os << curIndent << "if (gutils->width != 1) {\n"
284+
<< curIndent << " " << argName << "_" << (idx - 1)
285+
<< " = builder.create<enzyme::BroadcastOp>(\n"
286+
<< curIndent << " op.getLoc(),\n"
287+
<< curIndent << " " << argName << "_" << (idx - 1) << ",\n"
288+
<< curIndent
289+
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
290+
<< curIndent << "}";
290291
}
291292

292293
if (lookup && intrinsic != MLIRDerivatives)

0 commit comments

Comments
 (0)