Skip to content

Commit b531679

Browse files
committed
test via zygote
1 parent 261ec38 commit b531679

File tree

8 files changed

+53
-0
lines changed

8 files changed

+53
-0
lines changed

benchmark/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
55
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
66
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
77
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
8+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
89
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
910
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1011

@@ -15,6 +16,7 @@ ComponentArrays = "0.15"
1516
DifferentiationInterface = "0.6"
1617
Lux = "1"
1718
PkgBenchmark = "0.2"
19+
SciMLSensitivity = "7"
1820
StableRNGs = "1"
1921
Zygote = "0.6, 0.7"
2022
julia = "1.10"

benchmark/benchmarks.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ icnf = ContinuousNormalizingFlows.construct(
3838
steer_rate = 1.0f-1,
3939
λ₃ = 1.0f-2,
4040
rng,
41+
sol_kwargs = (
42+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
43+
autodiff = true,
44+
autojacvec = SciMLSensitivity.ZygoteVJP(),
45+
),
46+
),
4147
)
4248
ps, st = Lux.setup(icnf.rng, icnf)
4349
ps = ComponentArrays.ComponentArray(ps)
@@ -84,6 +90,12 @@ icnf2 = ContinuousNormalizingFlows.construct(
8490
steer_rate = 1.0f-1,
8591
λ₃ = 1.0f-2,
8692
rng,
93+
sol_kwargs = (
94+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
95+
autodiff = true,
96+
autojacvec = SciMLSensitivity.ZygoteVJP(),
97+
),
98+
),
8799
)
88100

89101
function diff_loss_tn2(x)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1212
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1313
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1414
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
15+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1516
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1617
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -30,6 +31,7 @@ JET = "0.9"
3031
Lux = "1"
3132
MLJBase = "1"
3233
SciMLBase = "2"
34+
SciMLSensitivity = "7"
3335
StableRNGs = "1"
3436
TerminalLoggers = "0.1"
3537
Zygote = "0.6, 0.7"

test/call_tests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ Test.@testset "Call Tests" begin
150150
resource,
151151
steer_rate = convert(data_type, 1.0e-1),
152152
λ₃ = convert(data_type, 1.0e-2),
153+
sol_kwargs = (
154+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
155+
autodiff = true,
156+
autojacvec = SciMLSensitivity.ZygoteVJP(),
157+
),
158+
),
153159
),
154160
ContinuousNormalizingFlows.construct(
155161
mt,
@@ -159,6 +165,12 @@ Test.@testset "Call Tests" begin
159165
compute_mode,
160166
inplace,
161167
resource,
168+
sol_kwargs = (
169+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
170+
autodiff = true,
171+
autojacvec = SciMLSensitivity.ZygoteVJP(),
172+
),
173+
),
162174
),
163175
)
164176
ps, st = Lux.setup(icnf.rng, icnf)

test/fit_tests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ Test.@testset "Fit Tests" begin
145145
resource,
146146
steer_rate = convert(data_type, 1.0e-1),
147147
λ₃ = convert(data_type, 1.0e-2),
148+
sol_kwargs = (
149+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
150+
autodiff = true,
151+
autojacvec = SciMLSensitivity.ZygoteVJP(),
152+
),
153+
),
148154
),
149155
ContinuousNormalizingFlows.construct(
150156
mt,
@@ -154,6 +160,12 @@ Test.@testset "Fit Tests" begin
154160
compute_mode,
155161
inplace,
156162
resource,
163+
sol_kwargs = (
164+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
165+
autodiff = true,
166+
autojacvec = SciMLSensitivity.ZygoteVJP(),
167+
),
168+
),
157169
),
158170
)
159171
if mt <: Union{

test/instability_tests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ Test.@testset "Instability" begin
2020
tspan = (0.0f0, 13.0f0),
2121
steer_rate = 1.0f-1,
2222
λ₃ = 1.0f-2,
23+
sol_kwargs = (
24+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
25+
autodiff = true,
26+
autojacvec = SciMLSensitivity.ZygoteVJP(),
27+
),
28+
),
2329
)
2430
ps, st = Lux.setup(icnf.rng, icnf)
2531
ps = ComponentArrays.ComponentArray(ps)

test/regression_tests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ Test.@testset "Regression Tests" begin
1616
steer_rate = 1.0f-1,
1717
λ₃ = 1.0f-2,
1818
rng,
19+
sol_kwargs = (
20+
sensealg = SciMLSensitivity.InterpolatingAdjoint(;
21+
autodiff = true,
22+
autojacvec = SciMLSensitivity.ZygoteVJP(),
23+
),
24+
),
1925
)
2026
ps, st = Lux.setup(icnf.rng, icnf)
2127
ps = ComponentArrays.ComponentArray(ps)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import ADTypes,
1111
Lux,
1212
MLJBase,
1313
SciMLBase,
14+
SciMLSensitivity,
1415
StableRNGs,
1516
TerminalLoggers,
1617
Test,

0 commit comments

Comments
 (0)