Skip to content

Commit 2a018dc

Browse files
committed
test: fix NaN in gradients
The initial conditions for the UDE are the same as the reference solution, leading to taking the derivative of sqrt at 0, which is NaN. Changing the loss function to squared l2loss fixes this.
1 parent 2f71bbf commit 2a018dc

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ OrdinaryDiffEq = "6.74"
2727
Random = "1"
2828
SafeTestsets = "0.1"
2929
SciMLStructures = "1.1.0"
30+
StableRNGs = "1"
3031
SymbolicIndexingInterface = "0.3.15"
3132
Symbolics = "5.27"
3233
Test = "1"
@@ -41,8 +42,9 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
4142
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
4243
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4344
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
45+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4446
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4547
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4648

4749
[targets]
48-
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "SymbolicIndexingInterface"]
50+
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "StableRNGs", "SymbolicIndexingInterface"]

test/lotka_volterra.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using OptimizationOptimisers: Adam
1010
using SciMLStructures
1111
using SciMLStructures: Tunable
1212
using ForwardDiff
13+
using StableRNGs
1314

1415
function lotka_ude()
1516
@variables t x(t)=3.1 y(t)=1.5
@@ -41,7 +42,9 @@ function lotka_true()
4142
end
4243

4344
model = lotka_ude()
44-
nn = create_ude_component(2, 2)
45+
46+
chain = multi_layer_feed_forward(2, 2)
47+
nn = create_ude_component(2, 2; chain, rng = StableRNG(42))
4548

4649
eqs = [connect(model.nn_in, nn.output)
4750
connect(model.nn_out, nn.input)]
@@ -71,7 +74,7 @@ function loss(x, (prob, sol_ref, get_vars, get_refs))
7174
loss = zero(eltype(x))
7275

7376
for i in eachindex(new_sol.u)
74-
loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
77+
loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
7578
end
7679

7780
if SciMLBase.successful_retcode(new_sol)
@@ -106,12 +109,16 @@ op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
106109
# false
107110
# end
108111

109-
res = solve(op, Adam(), maxiters = 2000)#, callback = plot_cb)
112+
res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
110113

111114
@test res.objective < 1
112115

113116
res_p = SciMLStructures.replace(Tunable(), prob.p, res)
114117
res_prob = remake(prob, p = res_p)
115118
res_sol = solve(res_prob, Rodas4())
116119

120+
# using Plots
121+
# plot(sol_ref, idxs = [model_true.x, model_true.y])
122+
# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
123+
117124
@test SciMLBase.successful_retcode(res_sol)

0 commit comments

Comments
 (0)