Skip to content

Commit 51f6885

Browse files
committed
fix tests
1 parent d679213 commit 51f6885

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

test/lotka_volterra.jl

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,18 @@ end
4848

4949
rbf(x) = exp.(-(x .^ 2))
5050

51-
chain = Lux.Chain(
52-
Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
53-
Lux.Dense(5, 2))
51+
chain = multi_layer_feed_forward(2, 2, width = 5, initial_scaling_factor = 1)
5452
ude_sys = lotka_ude(chain)
5553

56-
sys = mtkcompile(ude_sys, allow_symbolic = true)
54+
sys = mtkcompile(ude_sys)
55+
56+
@test length(equations(sys)) == 2
5757

5858
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 5.0))
5959

6060
model_true = mtkcompile(lotka_true())
6161
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 5.0))
62-
sol_ref = solve(prob_true, Vern9(), abstol = 1e-12, reltol = 1e-12)
62+
sol_ref = solve(prob_true, Vern9(), abstol = 1e-8, reltol = 1e-8)
6363

6464
ts = range(0, 5.0, length = 21)
6565
data = reduce(hcat, sol_ref(ts, idxs = [model_true.x, model_true.y]).u)
@@ -69,11 +69,10 @@ x0 = default_values(sys)[sys.nn.p]
6969
get_vars = getu(sys, [sys.x, sys.y])
7070
set_x = setsym_oop(sys, sys.nn.p)
7171

72-
function loss(x, (prob, sol_ref, get_vars, data, ts, set_x))
73-
# new_u0, new_p = set_x(prob, 1, x)
72+
function loss(x, (prob, get_vars, data, ts, set_x))
7473
new_u0, new_p = set_x(prob, x)
7574
new_prob = remake(prob, p = new_p, u0 = new_u0)
76-
new_sol = solve(new_prob, Vern9(), abstol = 1e-10, reltol = 1e-8, saveat = ts)
75+
new_sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8, saveat = ts)
7776

7877
if SciMLBase.successful_retcode(new_sol)
7978
mean(abs2.(reduce(hcat, get_vars(new_sol)) .- data))
@@ -84,7 +83,7 @@ end
8483

8584
of = OptimizationFunction{true}(loss, AutoZygote())
8685

87-
ps = (prob, sol_ref, get_vars, data, ts, set_x);
86+
ps = (prob, get_vars, data, ts, set_x);
8887

8988
@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9089
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
@@ -106,34 +105,36 @@ op = OptimizationProblem(of, x0, ps)
106105
# oh = []
107106

108107
# plot_cb = (opt_state, loss) -> begin
108+
# opt_state.iter % 500 ≠ 0 && return false
109109
# @info "step $(opt_state.iter), loss: $loss"
110110
# push!(oh, opt_state)
111111
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
112112
# new_prob = remake(prob, p = new_p)
113-
# sol = solve(new_prob, Rodas4())
113+
# sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8)
114114
# display(plot(sol))
115115
# false
116116
# end
117117

118-
res = solve(op, Adam(), maxiters = 10000)#, callback = plot_cb)
118+
res = solve(op, Adam(1e-3), maxiters = 25_000)#, callback = plot_cb)
119119

120120
display(res.stats)
121-
@test res.objective < 1
121+
@test res.objective < 1.5e-4
122+
123+
u0, p = set_x(prob, res.u)
124+
res_prob = remake(prob; u0, p)
125+
res_sol = solve(res_prob, Vern9(), abstol = 1e-8, reltol = 1e-8, saveat = ts)
122126

123-
res_p = set_x(prob, res.u)
124-
res_prob = remake(prob, p = res_p)
125-
res_sol = solve(res_prob, Vern9())
127+
@test SciMLBase.successful_retcode(res_sol)
128+
@test mean(abs2.(reduce(hcat, get_vars(res_sol)) .- data)) res.objective
126129

127130
# using Plots
128131
# plot(sol_ref, idxs = [model_true.x, model_true.y])
129132
# plot!(res_sol, idxs = [sys.x, sys.y])
130133

131-
@test SciMLBase.successful_retcode(res_sol)
132-
133134
function lotka_ude2()
134135
@variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
135136
@parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
136-
chain = multi_layer_feed_forward(2, 2)
137+
chain = multi_layer_feed_forward(2, 2; width = 5, initial_scaling_factor = 1)
137138
NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
138139
Dt = ModelingToolkit.D_nounits
139140

@@ -145,16 +146,16 @@ end
145146

146147
sys2 = mtkcompile(lotka_ude2())
147148

148-
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0))
149+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 5.0))
149150

150151
sol = solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-8)
151152

152153
@test SciMLBase.successful_retcode(sol)
153154

154-
set_x2 = setp_oop(sys2, sys2.p)
155-
ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
155+
set_x2 = setsym_oop(sys2, sys2.p)
156+
ps2 = (prob, get_vars, data, ts, set_x2);
156157
op2 = OptimizationProblem(of, x0, ps2)
157158

158-
res2 = solve(op2, Adam(), maxiters = 10000)
159+
res2 = solve(op2, Adam(1e-3), maxiters = 25_000)
159160

160161
@test res.u res2.u

0 commit comments

Comments
 (0)