Skip to content

Commit 34ae868

Browse files
committed
switch to GaussAdjoint
1 parent 75a2a1b commit 34ae868

File tree

6 files changed

+7
-9
lines changed

6 files changed

+7
-9
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ icnf = ContinuousNormalizingFlows.construct(
4141
sol_kwargs = (;
4242
save_everystep = false,
4343
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
44-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
44+
sensealg = SciMLSensitivity.GaussAdjoint(),
4545
),
4646
)
4747

@@ -61,7 +61,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
6161
sol_kwargs = (;
6262
save_everystep = false,
6363
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
64-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
64+
sensealg = SciMLSensitivity.GaussAdjoint(),
6565
),
6666
)
6767

examples/usage.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ icnf = construct(
4343
sol_kwargs = (;
4444
save_everystep = false,
4545
alg = DefaultODEAlgorithm(),
46-
sensealg = InterpolatingAdjoint(),
46+
sensealg = GaussAdjoint(),
4747
), # pass to the solver
4848
)
4949

test/ci_tests/regression_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test
2727
sol_kwargs = (;
2828
save_everystep = false,
2929
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
30-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
30+
sensealg = SciMLSensitivity.GaussAdjoint(),
3131
),
3232
)
3333

test/ci_tests/smoke_tests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be
134134
sol_kwargs = (;
135135
save_everystep = false,
136136
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
137-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
137+
sensealg = SciMLSensitivity.GaussAdjoint(),
138138
),
139139
)
140140
ps, st = LuxCore.setup(icnf.rng, icnf)
@@ -207,7 +207,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be
207207

208208
Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in
209209
adtypes
210-
211210
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken =
212211
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && (
213212
omode isa ContinuousNormalizingFlows.TrainMode || (

test/ci_tests/speed_tests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be
3232

3333
Test.@testset verbose = true showtiming = true failfast = false "$compute_mode" for compute_mode in
3434
compute_modes
35-
3635
@show compute_mode
3736

3837
rng = StableRNGs.StableRNG(1)
@@ -63,7 +62,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be
6362
sol_kwargs = (;
6463
save_everystep = false,
6564
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
66-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
65+
sensealg = SciMLSensitivity.GaussAdjoint(),
6766
),
6867
)
6968

test/quality_tests/checkby_JET_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg
117117
sol_kwargs = (;
118118
save_everystep = false,
119119
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
120-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
120+
sensealg = SciMLSensitivity.GaussAdjoint(),
121121
),
122122
)
123123
ps, st = LuxCore.setup(icnf.rng, icnf)

0 commit comments

Comments
 (0)