Skip to content

Commit 722f633

Browse files
committed
go by zygote
1 parent 4a7e996 commit 722f633

File tree

5 files changed

+10
-1
lines changed

5 files changed

+10
-1
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ icnf = ContinuousNormalizingFlows.construct(
5252
alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()),
5353
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
5454
autodiff = true,
55+
autojacvec = ZygoteVJP(),
5556
checkpointing = true,
5657
),
5758
),
@@ -124,6 +125,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
124125
alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()),
125126
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
126127
autodiff = true,
128+
autojacvec = ZygoteVJP(),
127129
checkpointing = true,
128130
),
129131
),

examples/usage.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ icnf = construct(
4242
abstol = eps(one(Float32)),
4343
maxiters = typemax(Int),
4444
alg = VCABM(; thread = True()),
45-
sensealg = InterpolatingAdjoint(; autodiff = true, checkpointing = true),
45+
sensealg = InterpolatingAdjoint(;
46+
autodiff = true,
47+
autojacvec = ZygoteVJP(),
48+
checkpointing = true,
49+
),
4650
), # pass to the solver
4751
)
4852

test/checkby_JET_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Test.@testset "CheckByJET" begin
3131
alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()),
3232
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
3333
autodiff = true,
34+
autojacvec = ZygoteVJP(),
3435
checkpointing = true,
3536
),
3637
),

test/regression_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Test.@testset "Regression Tests" begin
2727
alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()),
2828
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
2929
autodiff = true,
30+
autojacvec = ZygoteVJP(),
3031
checkpointing = true,
3132
),
3233
),

test/smoke_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ Test.@testset "Smoke Tests" begin
130130
alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()),
131131
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
132132
autodiff = true,
133+
autojacvec = ZygoteVJP(),
133134
checkpointing = true,
134135
),
135136
),

0 commit comments

Comments
 (0)