@@ -3,7 +3,7 @@ using JET
33using ModelingToolkitNeuralNets
44using ModelingToolkit
55using ModelingToolkitStandardLibrary. Blocks
6- using OrdinaryDiffEq
6+ using OrdinaryDiffEqVerner
77using SymbolicIndexingInterface
88using Optimization
99using OptimizationOptimisers: Adam
@@ -14,22 +14,24 @@ using StableRNGs
1414using DifferentiationInterface
1515using SciMLSensitivity
1616using Zygote: Zygote
17+ using Statistics
1718
1819function lotka_ude ()
1920 @variables t x (t)= 3.1 y (t)= 1.5
2021 @parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
2122 Dt = ModelingToolkit. D_nounits
22- @named nn_in = RealInputArray (nin = 2 )
23- @named nn_out = RealOutputArray (nout = 2 )
23+
24+ chain = multi_layer_feed_forward (2 , 2 )
25+ @named nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
2426
2527 eqs = [
26- Dt (x) ~ α * x + nn_in . u [1 ],
27- Dt (y) ~ - δ * y + nn_in . u [2 ],
28- nn_out . u [1 ] ~ x,
29- nn_out . u [2 ] ~ y
28+ Dt (x) ~ α * x + nn . outputs [1 ],
29+ Dt (y) ~ - δ * y + nn . outputs [2 ],
30+ nn . inputs [1 ] ~ x,
31+ nn . inputs [2 ] ~ y
3032 ]
31- return ODESystem (
32- eqs, ModelingToolkit. t_nounits, name = :lotka , systems = [nn_in, nn_out ])
33+ return System (
34+ eqs, ModelingToolkit. t_nounits, name = :lotka , systems = [nn ])
3335end
3436
3537function lotka_true ()
@@ -41,49 +43,33 @@ function lotka_true()
4143 Dt (x) ~ α * x - β * x * y,
4244 Dt (y) ~ - δ * y + δ * x * y
4345 ]
44- return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka_true )
46+ return System (eqs, ModelingToolkit. t_nounits, name = :lotka_true )
4547end
4648
47- model = lotka_ude ()
48-
49- chain = multi_layer_feed_forward (2 , 2 )
50- @named nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
51-
52- eqs = [connect (model. nn_in, nn. output)
53- connect (model. nn_out, nn. input)]
54- eqs = [model. nn_in. u ~ nn. output. u, model. nn_out. u ~ nn. input. u]
55- ude_sys = complete (ODESystem (
56- eqs, ModelingToolkit. t_nounits, systems = [model, nn],
57- name = :ude_sys ))
49+ ude_sys = lotka_ude ()
5850
59- sys = structural_simplify (ude_sys, allow_symbolic = true )
51+ sys = mtkcompile (ude_sys, allow_symbolic = true )
6052
61- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ), [] )
53+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1.0 ))
6254
63- model_true = structural_simplify (lotka_true ())
64- prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [] )
65- sol_ref = solve (prob_true, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
55+ model_true = mtkcompile (lotka_true ())
56+ prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ))
57+ sol_ref = solve (prob_true, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
6658
67- x0 = default_values (sys)[nn. p]
59+ x0 = default_values (sys)[sys . nn. p]
6860
69- get_vars = getu (sys, [sys. lotka . x, sys. lotka . y])
61+ get_vars = getu (sys, [sys. x, sys. y])
7062get_refs = getu (model_true, [model_true. x, model_true. y])
71- set_x = setp_oop (sys, nn. p)
63+ set_x = setp_oop (sys, sys . nn. p)
7264
7365function loss (x, (prob, sol_ref, get_vars, get_refs, set_x))
7466 new_p = set_x (prob, x)
7567 new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
7668 ts = sol_ref. t
77- new_sol = solve (new_prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
78-
79- loss = zero (eltype (x))
80-
81- for i in eachindex (new_sol. u)
82- loss += sum (abs2 .(get_vars (new_sol, i) .- get_refs (sol_ref, i)))
83- end
69+ new_sol = solve (new_prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
8470
8571 if SciMLBase. successful_retcode (new_sol)
86- loss
72+ mean ( abs2 .( reduce (hcat, get_vars (new_sol)) .- reduce (hcat, get_refs (sol_ref))))
8773 else
8874 Inf
8975 end
@@ -103,8 +89,8 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
10389@test all (.! isnan .(∇l1))
10490@test ! iszero (∇l1)
10591
106- @test ∇l1≈ ∇l2 rtol= 1e-3
107- @test ∇l1≈ ∇l3 rtol = 1e-5
92+ @test ∇l1≈ ∇l2 rtol= 1e-5
93+ @test ∇l1 ≈ ∇l3
10894
10995op = OptimizationProblem (of, x0, ps)
11096
@@ -124,15 +110,16 @@ op = OptimizationProblem(of, x0, ps)
124110
125111res = solve (op, Adam (), maxiters = 10000 )# , callback = plot_cb)
126112
113+ display (res. stats)
127114@test res. objective < 1
128115
129116res_p = set_x (prob, res. u)
130117res_prob = remake (prob, p = res_p)
131- res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref . t )
118+ res_sol = solve (res_prob, Vern9 () )
132119
133120# using Plots
134121# plot(sol_ref, idxs = [model_true.x, model_true.y])
135- # plot!(res_sol, idxs = [sys.lotka. x, sys.lotka .y])
122+ # plot!(res_sol, idxs = [sys.x, sys.y])
136123
137124@test SciMLBase. successful_retcode (res_sol)
138125
@@ -146,14 +133,14 @@ function lotka_ude2()
146133 eqs = [pred ~ NN ([x, y], p)
147134 Dt (x) ~ α * x + pred[1 ]
148135 Dt (y) ~ - δ * y + pred[2 ]]
149- return ODESystem (eqs, ModelingToolkit. t_nounits, name = :lotka )
136+ return System (eqs, ModelingToolkit. t_nounits, name = :lotka )
150137end
151138
152- sys2 = structural_simplify (lotka_ude2 ())
139+ sys2 = mtkcompile (lotka_ude2 ())
153140
154- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ), [] )
141+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys2, [], (0 , 1.0 ))
155142
156- sol = solve (prob, Rodas5P (), abstol = 1e-10 , reltol = 1e-8 )
143+ sol = solve (prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
157144
158145@test SciMLBase. successful_retcode (sol)
159146
0 commit comments