Skip to content

Commit 395a619

Browse files
Merge pull request #79 from SciML/smc/test
fix typos in test
2 parents ba8d973 + 78eac20 commit 395a619

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkitNeuralNets"
22
uuid = "f162e290-f571-43a6-83d9-22ecc16da15f"
33
authors = ["Sebastian Micluța-Câmpeanu <[email protected]> and contributors"]
4-
version = "2.0.0"
4+
version = "2.1.0"
55

66
[deps]
77
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"

test/lotka_volterra.jl

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ using DifferentiationInterface
1515
using SciMLSensitivity
1616
using Zygote: Zygote
1717
using Statistics
18+
using Lux
1819

19-
function lotka_ude()
20+
function lotka_ude(chain)
2021
@variables t x(t)=3.1 y(t)=1.5
2122
@parameters α=1.3 [tunable=false] δ=1.8 [tunable=false]
2223
Dt = ModelingToolkit.D_nounits
2324

24-
chain = multi_layer_feed_forward(2, 2)
2525
@named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
2626

2727
eqs = [
@@ -36,48 +36,54 @@ end
3636

3737
function lotka_true()
3838
@variables t x(t)=3.1 y(t)=1.5
39-
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8
39+
@parameters α=1.3 [tunable=false] β=0.9 γ=0.8 δ=1.8 [tunable=false]
4040
Dt = ModelingToolkit.D_nounits
4141

4242
eqs = [
4343
Dt(x) ~ α * x - β * x * y,
44-
Dt(y) ~ -δ * y + δ * x * y
44+
Dt(y) ~ -δ * y + γ * x * y
4545
]
4646
return System(eqs, ModelingToolkit.t_nounits, name = :lotka_true)
4747
end
4848

49-
ude_sys = lotka_ude()
49+
rbf(x) = exp.(-(x .^ 2))
5050

51-
sys = mtkcompile(ude_sys, allow_symbolic = true)
51+
chain = multi_layer_feed_forward(2, 2, width = 5, initial_scaling_factor = 1)
52+
ude_sys = lotka_ude(chain)
5253

53-
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0))
54+
sys = mtkcompile(ude_sys)
55+
56+
@test length(equations(sys)) == 2
57+
58+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 5.0))
5459

5560
model_true = mtkcompile(lotka_true())
56-
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0))
57-
sol_ref = solve(prob_true, Vern9(), abstol = 1e-10, reltol = 1e-8)
61+
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 5.0))
62+
sol_ref = solve(prob_true, Vern9(), abstol = 1e-8, reltol = 1e-8)
63+
64+
ts = range(0, 5.0, length = 21)
65+
data = reduce(hcat, sol_ref(ts, idxs = [model_true.x, model_true.y]).u)
5866

5967
x0 = default_values(sys)[sys.nn.p]
6068

6169
get_vars = getu(sys, [sys.x, sys.y])
62-
get_refs = getu(model_true, [model_true.x, model_true.y])
63-
set_x = setp_oop(sys, sys.nn.p)
70+
set_x = setsym_oop(sys, sys.nn.p)
6471

65-
function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
66-
new_p = set_x(prob, x)
67-
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
68-
ts = sol_ref.t
69-
new_sol = solve(new_prob, Vern9(), abstol = 1e-10, reltol = 1e-8, saveat = ts)
72+
function loss(x, (prob, get_vars, data, ts, set_x))
73+
new_u0, new_p = set_x(prob, x)
74+
new_prob = remake(prob, p = new_p, u0 = new_u0)
75+
new_sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8, saveat = ts)
7076

7177
if SciMLBase.successful_retcode(new_sol)
72-
mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
78+
mean(abs2.(reduce(hcat, get_vars(new_sol)) .- data))
7379
else
7480
Inf
7581
end
7682
end
7783

7884
of = OptimizationFunction{true}(loss, AutoZygote())
7985

80-
ps = (prob, sol_ref, get_vars, get_refs, set_x);
86+
ps = (prob, get_vars, data, ts, set_x);
8187

8288
@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
8389
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
@@ -89,7 +95,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
8995
@test all(.!isnan.(∇l1))
9096
@test !iszero(∇l1)
9197

92-
@test ∇l1∇l2 rtol=1e-5
98+
@test ∇l1∇l2 rtol=1e-4
9399
@test ∇l1 ∇l3
94100

95101
op = OptimizationProblem(of, x0, ps)
@@ -99,34 +105,36 @@ op = OptimizationProblem(of, x0, ps)
99105
# oh = []
100106

101107
# plot_cb = (opt_state, loss) -> begin
108+
# opt_state.iter % 500 ≠ 0 && return false
102109
# @info "step $(opt_state.iter), loss: $loss"
103110
# push!(oh, opt_state)
104111
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
105112
# new_prob = remake(prob, p = new_p)
106-
# sol = solve(new_prob, Rodas4())
113+
# sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8)
107114
# display(plot(sol))
108115
# false
109116
# end
110117

111-
res = solve(op, Adam(), maxiters = 10000)#, callback = plot_cb)
118+
res = solve(op, Adam(1e-3), maxiters = 25_000)#, callback = plot_cb)
112119

113120
display(res.stats)
114-
@test res.objective < 1
121+
@test res.objective < 1.5e-4
115122

116-
res_p = set_x(prob, res.u)
117-
res_prob = remake(prob, p = res_p)
118-
res_sol = solve(res_prob, Vern9())
123+
u0, p = set_x(prob, res.u)
124+
res_prob = remake(prob; u0, p)
125+
res_sol = solve(res_prob, Vern9(), abstol = 1e-8, reltol = 1e-8, saveat = ts)
126+
127+
@test SciMLBase.successful_retcode(res_sol)
128+
@test mean(abs2.(reduce(hcat, get_vars(res_sol)) .- data)) res.objective
119129

120130
# using Plots
121131
# plot(sol_ref, idxs = [model_true.x, model_true.y])
122132
# plot!(res_sol, idxs = [sys.x, sys.y])
123133

124-
@test SciMLBase.successful_retcode(res_sol)
125-
126134
function lotka_ude2()
127135
@variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
128136
@parameters α=1.3 [tunable=false] δ=1.8 [tunable=false]
129-
chain = multi_layer_feed_forward(2, 2)
137+
chain = multi_layer_feed_forward(2, 2; width = 5, initial_scaling_factor = 1)
130138
NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
131139
Dt = ModelingToolkit.D_nounits
132140

@@ -138,16 +146,16 @@ end
138146

139147
sys2 = mtkcompile(lotka_ude2())
140148

141-
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0))
149+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 5.0))
142150

143151
sol = solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-8)
144152

145153
@test SciMLBase.successful_retcode(sol)
146154

147-
set_x2 = setp_oop(sys2, sys2.p)
148-
ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
155+
set_x2 = setsym_oop(sys2, sys2.p)
156+
ps2 = (prob, get_vars, data, ts, set_x2);
149157
op2 = OptimizationProblem(of, x0, ps2)
150158

151-
res2 = solve(op2, Adam(), maxiters = 10000)
159+
res2 = solve(op2, Adam(1e-3), maxiters = 25_000)
152160

153161
@test res.u res2.u

0 commit comments

Comments
 (0)