Skip to content

Commit bca35ac

Browse files
committed
update tests
1 parent c7f0d0c commit bca35ac

File tree

2 files changed

+38
-51
lines changed

2 files changed

+38
-51
lines changed

test/lotka_volterra.jl

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using JET
33
using ModelingToolkitNeuralNets
44
using ModelingToolkit
55
using ModelingToolkitStandardLibrary.Blocks
6-
using OrdinaryDiffEq
6+
using OrdinaryDiffEqVerner
77
using SymbolicIndexingInterface
88
using Optimization
99
using OptimizationOptimisers: Adam
@@ -14,22 +14,24 @@ using StableRNGs
1414
using DifferentiationInterface
1515
using SciMLSensitivity
1616
using Zygote: Zygote
17+
using Statistics
1718

1819
function lotka_ude()
1920
@variables t x(t)=3.1 y(t)=1.5
2021
@parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
2122
Dt = ModelingToolkit.D_nounits
22-
@named nn_in = RealInputArray(nin = 2)
23-
@named nn_out = RealOutputArray(nout = 2)
23+
24+
chain = multi_layer_feed_forward(2, 2)
25+
@named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
2426

2527
eqs = [
26-
Dt(x) ~ α * x + nn_in.u[1],
27-
Dt(y) ~ -δ * y + nn_in.u[2],
28-
nn_out.u[1] ~ x,
29-
nn_out.u[2] ~ y
28+
Dt(x) ~ α * x + nn.outputs[1],
29+
Dt(y) ~ -δ * y + nn.outputs[2],
30+
nn.inputs[1] ~ x,
31+
nn.inputs[2] ~ y
3032
]
31-
return ODESystem(
32-
eqs, ModelingToolkit.t_nounits, name = :lotka, systems = [nn_in, nn_out])
33+
return System(
34+
eqs, ModelingToolkit.t_nounits, name = :lotka, systems = [nn])
3335
end
3436

3537
function lotka_true()
@@ -41,49 +43,33 @@ function lotka_true()
4143
Dt(x) ~ α * x - β * x * y,
4244
Dt(y) ~ -δ * y + δ * x * y
4345
]
44-
return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka_true)
46+
return System(eqs, ModelingToolkit.t_nounits, name = :lotka_true)
4547
end
4648

47-
model = lotka_ude()
48-
49-
chain = multi_layer_feed_forward(2, 2)
50-
@named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
51-
52-
eqs = [connect(model.nn_in, nn.output)
53-
connect(model.nn_out, nn.input)]
54-
eqs = [model.nn_in.u ~ nn.output.u, model.nn_out.u ~ nn.input.u]
55-
ude_sys = complete(ODESystem(
56-
eqs, ModelingToolkit.t_nounits, systems = [model, nn],
57-
name = :ude_sys))
49+
ude_sys = lotka_ude()
5850

59-
sys = structural_simplify(ude_sys, allow_symbolic = true)
51+
sys = mtkcompile(ude_sys, allow_symbolic = true)
6052

61-
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
53+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0))
6254

63-
model_true = structural_simplify(lotka_true())
64-
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
65-
sol_ref = solve(prob_true, Rodas5P(), abstol = 1e-10, reltol = 1e-8)
55+
model_true = mtkcompile(lotka_true())
56+
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0))
57+
sol_ref = solve(prob_true, Vern9(), abstol = 1e-10, reltol = 1e-8)
6658

67-
x0 = default_values(sys)[nn.p]
59+
x0 = default_values(sys)[sys.nn.p]
6860

69-
get_vars = getu(sys, [sys.lotka.x, sys.lotka.y])
61+
get_vars = getu(sys, [sys.x, sys.y])
7062
get_refs = getu(model_true, [model_true.x, model_true.y])
71-
set_x = setp_oop(sys, nn.p)
63+
set_x = setp_oop(sys, sys.nn.p)
7264

7365
function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
7466
new_p = set_x(prob, x)
7567
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
7668
ts = sol_ref.t
77-
new_sol = solve(new_prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8, saveat = ts)
78-
79-
loss = zero(eltype(x))
80-
81-
for i in eachindex(new_sol.u)
82-
loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
83-
end
69+
new_sol = solve(new_prob, Vern9(), abstol = 1e-10, reltol = 1e-8, saveat = ts)
8470

8571
if SciMLBase.successful_retcode(new_sol)
86-
loss
72+
mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
8773
else
8874
Inf
8975
end
@@ -103,8 +89,8 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
10389
@test all(.!isnan.(∇l1))
10490
@test !iszero(∇l1)
10591

106-
@test ∇l1∇l2 rtol=1e-3
107-
@test ∇l1∇l3 rtol=1e-5
92+
@test ∇l1∇l2 rtol=1e-5
93+
@test ∇l1 ∇l3
10894

10995
op = OptimizationProblem(of, x0, ps)
11096

@@ -124,15 +110,16 @@ op = OptimizationProblem(of, x0, ps)
124110

125111
res = solve(op, Adam(), maxiters = 10000)#, callback = plot_cb)
126112

113+
display(res.stats)
127114
@test res.objective < 1
128115

129116
res_p = set_x(prob, res.u)
130117
res_prob = remake(prob, p = res_p)
131-
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
118+
res_sol = solve(res_prob, Vern9())
132119

133120
# using Plots
134121
# plot(sol_ref, idxs = [model_true.x, model_true.y])
135-
# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
122+
# plot!(res_sol, idxs = [sys.x, sys.y])
136123

137124
@test SciMLBase.successful_retcode(res_sol)
138125

@@ -146,14 +133,14 @@ function lotka_ude2()
146133
eqs = [pred ~ NN([x, y], p)
147134
Dt(x) ~ α * x + pred[1]
148135
Dt(y) ~ -δ * y + pred[2]]
149-
return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka)
136+
return System(eqs, ModelingToolkit.t_nounits, name = :lotka)
150137
end
151138

152-
sys2 = structural_simplify(lotka_ude2())
139+
sys2 = mtkcompile(lotka_ude2())
153140

154-
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0), [])
141+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0))
155142

156-
sol = solve(prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8)
143+
sol = solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-8)
157144

158145
@test SciMLBase.successful_retcode(sol)
159146

test/macro.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ModelingToolkit, Symbolics
22
using ModelingToolkit: t_nounits as t, D_nounits as D
3-
using OrdinaryDiffEq
3+
using OrdinaryDiffEqVerner
44
using ModelingToolkitNeuralNets
55
using ModelingToolkitStandardLibrary.Blocks
66
using Lux
@@ -28,14 +28,14 @@ end
2828
nn = NeuralNetworkBlock(n_input = 1, n_output = 1)
2929
end
3030
@equations begin
31-
connect(friction_ude.nn_in, nn.output)
32-
connect(friction_ude.nn_out, nn.input)
31+
connect(friction_ude.nn_in.u, nn.outputs)
32+
connect(friction_ude.nn_out.u, nn.inputs)
3333
end
3434
end
3535

36-
@mtkbuild sys = TestFriction_UDE()
36+
@mtkcompile sys = TestFriction_UDE()
3737

38-
prob = ODEProblem(sys, [], (0, 1.0), [])
39-
sol = solve(prob, Rodas4())
38+
prob = ODEProblem(sys, [], (0, 1.0))
39+
sol = solve(prob, Vern9())
4040

4141
@test SciMLBase.successful_retcode(sol)

0 commit comments

Comments
 (0)