@@ -49,16 +49,16 @@ function friction_true()
4949 eqs = [
5050 Dt(y) ~ Fu - friction(y)
5151 ]
52- return ODESystem (eqs, t, name = :friction_true)
52+ return System (eqs, t, name = :friction_true)
5353end
5454```
5555
5656Now that we have defined the model, we will simulate it from 0 to 0.1 seconds.
5757
5858``` @example friction
59- model_true = structural_simplify (friction_true())
60- prob_true = ODEProblem(model_true, [], (0, 0.1), [] )
61- sol_ref = solve(prob_true, Rodas4 (); saveat = 0.001)
59+ model_true = mtkcompile (friction_true())
60+ prob_true = ODEProblem(model_true, [], (0, 0.1))
61+ sol_ref = solve(prob_true, Vern7 (); saveat = 0.001)
6262```
6363
6464Let's plot it.
@@ -81,28 +81,23 @@ Now, we will try to learn the same friction model using a neural network. We wil
8181function friction_ude(Fu)
8282 @variables y(t) = 0.0
8383 @constants Fu = Fu
84- @named nn_in = RealInputArray(nin = 1)
85- @named nn_out = RealOutputArray(nout = 1)
86- eqs = [Dt(y) ~ Fu - nn_in.u[1]
87- y ~ nn_out.u[1]]
88- return ODESystem(eqs, t, name = :friction, systems = [nn_in, nn_out])
89- end
9084
91- Fu = 120.0
92- model = friction_ude(Fu)
85+ chain = Lux.Chain(
86+ Lux.Dense(1 => 10, Lux.mish, use_bias = false),
87+ Lux.Dense(10 => 10, Lux.mish, use_bias = false),
88+ Lux.Dense(10 => 1, use_bias = false)
89+ )
90+ @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
9391
94- chain = Lux.Chain(
95- Lux.Dense(1 => 10, Lux.mish, use_bias = false),
96- Lux.Dense(10 => 10, Lux.mish, use_bias = false),
97- Lux.Dense(10 => 1, use_bias = false)
98- )
99- @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
92+ eqs = [Dt(y) ~ Fu - nn.outputs[1]
93+ y ~ nn.inputs[1]]
94+ return System(eqs, t, name = :friction, systems = [nn])
95+ end
10096
101- eqs = [connect(model.nn_in, nn.output)
102- connect(model.nn_out, nn.input)]
97+ Fu = 120.0
10398
104- ude_sys = complete(ODESystem(eqs, t, systems = [model, nn], name = :ude_sys) )
105- sys = structural_simplify (ude_sys)
99+ ude_sys = friction_ude(Fu )
100+ sys = mtkcompile (ude_sys)
106101```
107102
108103## Optimization Setup
@@ -114,22 +109,19 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
114109 new_p = set_x(prob, x)
115110 new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
116111 ts = sol_ref.t
117- new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
118- loss = zero(eltype(x))
119- for i in eachindex(new_sol.u)
120- loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
121- end
112+ new_sol = solve(new_prob, Vern7(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
113+
122114 if SciMLBase.successful_retcode(new_sol)
123- loss
115+ mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
124116 else
125117 Inf
126118 end
127119end
128120
129- of = OptimizationFunction{true} (loss, AutoForwardDiff())
121+ of = OptimizationFunction(loss, AutoForwardDiff())
130122
131- prob = ODEProblem(sys, [], (0, 0.1), [] )
132- get_vars = getu(sys, [sys.friction. y])
123+ prob = ODEProblem(sys, [], (0, 0.1))
124+ get_vars = getu(sys, [sys.y])
133125get_refs = getu(model_true, [model_true.y])
134126set_x = setp_oop(sys, sys.nn.p)
135127x0 = default_values(sys)[sys.nn.p]
@@ -150,31 +142,31 @@ We now have a trained neural network! We can check whether running the simulatio
150142``` @example friction
151143res_p = set_x(prob, res.u)
152144res_prob = remake(prob, p = res_p)
153- res_sol = solve(res_prob, Rodas4 (), saveat = sol_ref.t)
145+ res_sol = solve(res_prob, Vern7 (), saveat = sol_ref.t)
154146@test first.(sol_ref.u)≈first.(res_sol.u) rtol=1e-3 #hide
155- @test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.output.u ], 1)) rtol=1e-1 #hide
147+ @test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.outputs ], 1)) rtol=1e-1 #hide
156148nothing #hide
157149```
158150
159151Also, it would be interesting to check the simulation before the training to get an idea of the starting point of the network.
160152
161153``` @example friction
162- initial_sol = solve(prob, Rodas4 (), saveat = sol_ref.t)
154+ initial_sol = solve(prob, Vern7 (), saveat = sol_ref.t)
163155```
164156
165157Now we plot it.
166158
167159``` @example friction
168160scatter(sol_ref, idxs = [model_true.y], label = "ground truth velocity")
169- plot!(res_sol, idxs = [sys.friction. y], label = "velocity after training")
170- plot!(initial_sol, idxs = [sys.friction. y], label = "velocity before training")
161+ plot!(res_sol, idxs = [sys.y], label = "velocity after training")
162+ plot!(initial_sol, idxs = [sys.y], label = "velocity before training")
171163```
172164
173165It matches the data well! Let's also check the predictions for the friction force and whether the network learnt the friction model or not.
174166
175167``` @example friction
176168scatter(sol_ref.t, friction.(first.(sol_ref.u)), label = "ground truth friction")
177- plot!(res_sol.t, getindex.(res_sol[sys.nn.output.u ], 1),
169+ plot!(res_sol.t, getindex.(res_sol[sys.nn.outputs ], 1),
178170 label = "friction from neural network")
179171```
180172
0 commit comments