Skip to content

Commit cad7d0a

Browse files
committed
test enzyme with reverse mode
1 parent 5b1c25e commit cad7d0a

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ icnf = construct(
5050
nn,
5151
nvars, # number of variables
5252
naugs; # number of augmented dimensions
53-
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
53+
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
5454
# inplace = true, # use the inplace version of functions
5555
# resource = CUDALibs(), # process data by GPU
5656
tspan = (0.0f0, 13.0f0), # have bigger time span

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ icnf = ContinuousNormalizingFlows.construct(
3535
naugs;
3636
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
3737
ADTypes.AutoEnzyme(;
38-
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
38+
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
3939
function_annotation = Enzyme.Const,
4040
),
4141
),
@@ -94,7 +94,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
9494
inplace = true,
9595
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
9696
ADTypes.AutoEnzyme(;
97-
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
97+
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
9898
function_annotation = Enzyme.Const,
9999
),
100100
),

src/base_icnf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function construct(
66
data_type::Type{<:AbstractFloat} = Float32,
77
compute_mode::ComputeMode = DIJacVecMatrixMode(
88
ADTypes.AutoEnzyme(;
9-
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
9+
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
1010
function_annotation = Enzyme.Const,
1111
),
1212
),

test/regression_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Test.@testset "Regression Tests" begin
1313
naugs;
1414
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
1515
ADTypes.AutoEnzyme(;
16-
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
16+
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
1717
function_annotation = Enzyme.Const,
1818
),
1919
),

0 commit comments

Comments
 (0)