@@ -3,7 +3,7 @@ using JET
3
3
using ModelingToolkitNeuralNets
4
4
using ModelingToolkit
5
5
using ModelingToolkitStandardLibrary. Blocks
6
- using OrdinaryDiffEq
6
+ using OrdinaryDiffEqVerner
7
7
using SymbolicIndexingInterface
8
8
using Optimization
9
9
using OptimizationOptimisers: Adam
@@ -14,22 +14,24 @@ using StableRNGs
14
14
using DifferentiationInterface
15
15
using SciMLSensitivity
16
16
using Zygote: Zygote
17
+ using Statistics
17
18
18
19
function lotka_ude ()
19
20
@variables t x (t)= 3.1 y (t)= 1.5
20
21
@parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
21
22
Dt = ModelingToolkit. D_nounits
22
- @named nn_in = RealInputArray (nin = 2 )
23
- @named nn_out = RealOutputArray (nout = 2 )
23
+
24
+ chain = multi_layer_feed_forward (2 , 2 )
25
+ @named nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
24
26
25
27
eqs = [
26
- Dt (x) ~ α * x + nn_in . u [1 ],
27
- Dt (y) ~ - δ * y + nn_in . u [2 ],
28
- nn_out . u [1 ] ~ x,
29
- nn_out . u [2 ] ~ y
28
+ Dt (x) ~ α * x + nn . outputs [1 ],
29
+ Dt (y) ~ - δ * y + nn . outputs [2 ],
30
+ nn . inputs [1 ] ~ x,
31
+ nn . inputs [2 ] ~ y
30
32
]
31
- return ODESystem (
32
- eqs, ModelingToolkit. t_nounits, name = :lotka , systems = [nn_in, nn_out ])
33
+ return System (
34
+ eqs, ModelingToolkit. t_nounits, name = :lotka , systems = [nn ])
33
35
end
34
36
35
37
function lotka_true ()
@@ -41,49 +43,33 @@ function lotka_true()
41
43
Dt (x) ~ α * x - β * x * y,
42
44
Dt (y) ~ - δ * y + δ * x * y
43
45
]
44
- return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka_true )
46
+ return System (eqs, ModelingToolkit. t_nounits, name = :lotka_true )
45
47
end
46
48
47
- model = lotka_ude ()
48
-
49
- chain = multi_layer_feed_forward (2 , 2 )
50
- @named nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
51
-
52
- eqs = [connect (model. nn_in, nn. output)
53
- connect (model. nn_out, nn. input)]
54
- eqs = [model. nn_in. u ~ nn. output. u, model. nn_out. u ~ nn. input. u]
55
- ude_sys = complete (ODESystem (
56
- eqs, ModelingToolkit. t_nounits, systems = [model, nn],
57
- name = :ude_sys ))
49
+ ude_sys = lotka_ude ()
58
50
59
- sys = structural_simplify (ude_sys, allow_symbolic = true )
51
+ sys = mtkcompile (ude_sys, allow_symbolic = true )
60
52
61
- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ), [] )
53
+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ))
62
54
63
- model_true = structural_simplify (lotka_true ())
64
- prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [] )
65
- sol_ref = solve (prob_true, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
55
+ model_true = mtkcompile (lotka_true ())
56
+ prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ))
57
+ sol_ref = solve (prob_true, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
66
58
67
- x0 = default_values (sys)[nn. p]
59
+ x0 = default_values (sys)[sys . nn. p]
68
60
69
- get_vars = getu (sys, [sys. lotka . x, sys. lotka . y])
61
+ get_vars = getu (sys, [sys. x, sys. y])
70
62
get_refs = getu (model_true, [model_true. x, model_true. y])
71
- set_x = setp_oop (sys, nn. p)
63
+ set_x = setp_oop (sys, sys . nn. p)
72
64
73
65
function loss (x, (prob, sol_ref, get_vars, get_refs, set_x))
74
66
new_p = set_x (prob, x)
75
67
new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
76
68
ts = sol_ref. t
77
- new_sol = solve (new_prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
78
-
79
- loss = zero (eltype (x))
80
-
81
- for i in eachindex (new_sol. u)
82
- loss += sum (abs2 .(get_vars (new_sol, i) .- get_refs (sol_ref, i)))
83
- end
69
+ new_sol = solve (new_prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
84
70
85
71
if SciMLBase. successful_retcode (new_sol)
86
- loss
72
+ mean ( abs2 .( reduce (hcat, get_vars (new_sol)) .- reduce (hcat, get_refs (sol_ref))))
87
73
else
88
74
Inf
89
75
end
@@ -103,8 +89,8 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
103
89
@test all (.! isnan .(∇l1))
104
90
@test ! iszero (∇l1)
105
91
106
- @test ∇l1≈ ∇l2 rtol= 1e-3
107
- @test ∇l1≈ ∇l3 rtol = 1e-5
92
+ @test ∇l1≈ ∇l2 rtol= 1e-5
93
+ @test ∇l1 ≈ ∇l3
108
94
109
95
op = OptimizationProblem (of, x0, ps)
110
96
@@ -124,15 +110,16 @@ op = OptimizationProblem(of, x0, ps)
124
110
125
111
res = solve (op, Adam (), maxiters = 10000 )# , callback = plot_cb)
126
112
113
+ display (res. stats)
127
114
@test res. objective < 1
128
115
129
116
res_p = set_x (prob, res. u)
130
117
res_prob = remake (prob, p = res_p)
131
- res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref . t )
118
+ res_sol = solve (res_prob, Vern9 () )
132
119
133
120
# using Plots
134
121
# plot(sol_ref, idxs = [model_true.x, model_true.y])
135
- # plot!(res_sol, idxs = [sys.lotka. x, sys.lotka .y])
122
+ # plot!(res_sol, idxs = [sys.x, sys.y])
136
123
137
124
@test SciMLBase. successful_retcode (res_sol)
138
125
@@ -146,14 +133,14 @@ function lotka_ude2()
146
133
eqs = [pred ~ NN ([x, y], p)
147
134
Dt (x) ~ α * x + pred[1 ]
148
135
Dt (y) ~ - δ * y + pred[2 ]]
149
- return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka )
136
+ return System (eqs, ModelingToolkit. t_nounits, name = :lotka )
150
137
end
151
138
152
- sys2 = structural_simplify (lotka_ude2 ())
139
+ sys2 = mtkcompile (lotka_ude2 ())
153
140
154
- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ), [] )
141
+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ))
155
142
156
- sol = solve (prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
143
+ sol = solve (prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
157
144
158
145
@test SciMLBase. successful_retcode (sol)
159
146
0 commit comments