@@ -18,6 +18,7 @@ using OptimizationOptimisers: Adam
1818using SciMLStructures
1919using SciMLStructures: Tunable
2020using SymbolicIndexingInterface
21+ using Statistics
2122using StableRNGs
2223using Lux
2324using Plots
@@ -49,16 +50,16 @@ function friction_true()
4950 eqs = [
5051 Dt(y) ~ Fu - friction(y)
5152 ]
52- return ODESystem (eqs, t, name = :friction_true)
53+ return System (eqs, t, name = :friction_true)
5354end
5455```
5556
5657Now that we have defined the model, we will simulate it from 0 to 0.1 seconds.
5758
5859``` @example friction
59- model_true = structural_simplify (friction_true())
60- prob_true = ODEProblem(model_true, [], (0, 0.1), [] )
61- sol_ref = solve(prob_true, Rodas4 (); saveat = 0.001)
60+ model_true = mtkcompile (friction_true())
61+ prob_true = ODEProblem(model_true, [], (0, 0.1))
62+ sol_ref = solve(prob_true, Vern7 (); saveat = 0.001)
6263```
6364
6465Let's plot it.
@@ -81,28 +82,23 @@ Now, we will try to learn the same friction model using a neural network. We wil
8182function friction_ude(Fu)
8283 @variables y(t) = 0.0
8384 @constants Fu = Fu
84- @named nn_in = RealInputArray(nin = 1)
85- @named nn_out = RealOutputArray(nout = 1)
86- eqs = [Dt(y) ~ Fu - nn_in.u[1]
87- y ~ nn_out.u[1]]
88- return ODESystem(eqs, t, name = :friction, systems = [nn_in, nn_out])
89- end
9085
91- Fu = 120.0
92- model = friction_ude(Fu)
86+ chain = Lux.Chain(
87+ Lux.Dense(1 => 10, Lux.mish, use_bias = false),
88+ Lux.Dense(10 => 10, Lux.mish, use_bias = false),
89+ Lux.Dense(10 => 1, use_bias = false)
90+ )
91+ @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
9392
94- chain = Lux.Chain(
95- Lux.Dense(1 => 10, Lux.mish, use_bias = false),
96- Lux.Dense(10 => 10, Lux.mish, use_bias = false),
97- Lux.Dense(10 => 1, use_bias = false)
98- )
99- @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
93+ eqs = [Dt(y) ~ Fu - nn.outputs[1]
94+ y ~ nn.inputs[1]]
95+ return System(eqs, t, name = :friction, systems = [nn])
96+ end
10097
101- eqs = [connect(model.nn_in, nn.output)
102- connect(model.nn_out, nn.input)]
98+ Fu = 120.0
10399
104- ude_sys = complete(ODESystem(eqs, t, systems = [model, nn], name = :ude_sys) )
105- sys = structural_simplify (ude_sys)
100+ ude_sys = friction_ude(Fu )
101+ sys = mtkcompile (ude_sys)
106102```
107103
108104## Optimization Setup
@@ -114,22 +110,19 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
114110 new_p = set_x(prob, x)
115111 new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
116112 ts = sol_ref.t
117- new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
118- loss = zero(eltype(x))
119- for i in eachindex(new_sol.u)
120- loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
121- end
113+ new_sol = solve(new_prob, Vern7(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
114+
122115 if SciMLBase.successful_retcode(new_sol)
123- loss
116+ mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
124117 else
125118 Inf
126119 end
127120end
128121
129- of = OptimizationFunction{true} (loss, AutoForwardDiff())
122+ of = OptimizationFunction(loss, AutoForwardDiff())
130123
131- prob = ODEProblem(sys, [], (0, 0.1), [] )
132- get_vars = getu(sys, [sys.friction. y])
124+ prob = ODEProblem(sys, [], (0, 0.1))
125+ get_vars = getu(sys, [sys.y])
133126get_refs = getu(model_true, [model_true.y])
134127set_x = setp_oop(sys, sys.nn.p)
135128x0 = default_values(sys)[sys.nn.p]
@@ -150,31 +143,31 @@ We now have a trained neural network! We can check whether running the simulatio
150143``` @example friction
151144res_p = set_x(prob, res.u)
152145res_prob = remake(prob, p = res_p)
153- res_sol = solve(res_prob, Rodas4 (), saveat = sol_ref.t)
146+ res_sol = solve(res_prob, Vern7 (), saveat = sol_ref.t)
154147@test first.(sol_ref.u)≈first.(res_sol.u) rtol=1e-3 #hide
155- @test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.output.u ], 1)) rtol=1e-1 #hide
148+ @test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.outputs ], 1)) rtol=1e-1 #hide
156149nothing #hide
157150```
158151
159152Also, it would be interesting to check the simulation before the training to get an idea of the starting point of the network.
160153
161154``` @example friction
162- initial_sol = solve(prob, Rodas4 (), saveat = sol_ref.t)
155+ initial_sol = solve(prob, Vern7 (), saveat = sol_ref.t)
163156```
164157
165158Now we plot it.
166159
167160``` @example friction
168161scatter(sol_ref, idxs = [model_true.y], label = "ground truth velocity")
169- plot!(res_sol, idxs = [sys.friction. y], label = "velocity after training")
170- plot!(initial_sol, idxs = [sys.friction. y], label = "velocity before training")
162+ plot!(res_sol, idxs = [sys.y], label = "velocity after training")
163+ plot!(initial_sol, idxs = [sys.y], label = "velocity before training")
171164```
172165
173166It matches the data well! Let's also check the predictions for the friction force and whether the network learnt the friction model or not.
174167
175168``` @example friction
176169scatter(sol_ref.t, friction.(first.(sol_ref.u)), label = "ground truth friction")
177- plot!(res_sol.t, getindex.(res_sol[sys.nn.output.u ], 1),
170+ plot!(res_sol.t, getindex.(res_sol[sys.nn.outputs ], 1),
178171 label = "friction from neural network")
179172```
180173
0 commit comments