@@ -49,16 +49,16 @@ function friction_true()
49
49
eqs = [
50
50
Dt(y) ~ Fu - friction(y)
51
51
]
52
- return ODESystem (eqs, t, name = :friction_true)
52
+ return System (eqs, t, name = :friction_true)
53
53
end
54
54
```
55
55
56
56
Now that we have defined the model, we will simulate it from 0 to 0.1 seconds.
57
57
58
58
``` @example friction
59
- model_true = structural_simplify (friction_true())
60
- prob_true = ODEProblem(model_true, [], (0, 0.1), [] )
61
- sol_ref = solve(prob_true, Rodas4 (); saveat = 0.001)
59
+ model_true = mtkcompile (friction_true())
60
+ prob_true = ODEProblem(model_true, [], (0, 0.1))
61
+ sol_ref = solve(prob_true, Vern7 (); saveat = 0.001)
62
62
```
63
63
64
64
Let's plot it.
@@ -81,28 +81,23 @@ Now, we will try to learn the same friction model using a neural network. We wil
81
81
function friction_ude(Fu)
82
82
@variables y(t) = 0.0
83
83
@constants Fu = Fu
84
- @named nn_in = RealInputArray(nin = 1)
85
- @named nn_out = RealOutputArray(nout = 1)
86
- eqs = [Dt(y) ~ Fu - nn_in.u[1]
87
- y ~ nn_out.u[1]]
88
- return ODESystem(eqs, t, name = :friction, systems = [nn_in, nn_out])
89
- end
90
84
91
- Fu = 120.0
92
- model = friction_ude(Fu)
85
+ chain = Lux.Chain(
86
+ Lux.Dense(1 => 10, Lux.mish, use_bias = false),
87
+ Lux.Dense(10 => 10, Lux.mish, use_bias = false),
88
+ Lux.Dense(10 => 1, use_bias = false)
89
+ )
90
+ @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
93
91
94
- chain = Lux.Chain(
95
- Lux.Dense(1 => 10, Lux.mish, use_bias = false),
96
- Lux.Dense(10 => 10, Lux.mish, use_bias = false),
97
- Lux.Dense(10 => 1, use_bias = false)
98
- )
99
- @named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
92
+ eqs = [Dt(y) ~ Fu - nn.outputs[1]
93
+ y ~ nn.inputs[1]]
94
+ return System(eqs, t, name = :friction, systems = [nn])
95
+ end
100
96
101
- eqs = [connect(model.nn_in, nn.output)
102
- connect(model.nn_out, nn.input)]
97
+ Fu = 120.0
103
98
104
- ude_sys = complete(ODESystem(eqs, t, systems = [model, nn], name = :ude_sys) )
105
- sys = structural_simplify (ude_sys)
99
+ ude_sys = friction_ude(Fu )
100
+ sys = mtkcompile (ude_sys)
106
101
```
107
102
108
103
## Optimization Setup
@@ -114,22 +109,19 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
114
109
new_p = set_x(prob, x)
115
110
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
116
111
ts = sol_ref.t
117
- new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
118
- loss = zero(eltype(x))
119
- for i in eachindex(new_sol.u)
120
- loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
121
- end
112
+ new_sol = solve(new_prob, Vern7(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
113
+
122
114
if SciMLBase.successful_retcode(new_sol)
123
- loss
115
+ mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
124
116
else
125
117
Inf
126
118
end
127
119
end
128
120
129
- of = OptimizationFunction{true} (loss, AutoForwardDiff())
121
+ of = OptimizationFunction(loss, AutoForwardDiff())
130
122
131
- prob = ODEProblem(sys, [], (0, 0.1), [] )
132
- get_vars = getu(sys, [sys.friction. y])
123
+ prob = ODEProblem(sys, [], (0, 0.1))
124
+ get_vars = getu(sys, [sys.y])
133
125
get_refs = getu(model_true, [model_true.y])
134
126
set_x = setp_oop(sys, sys.nn.p)
135
127
x0 = default_values(sys)[sys.nn.p]
@@ -150,31 +142,31 @@ We now have a trained neural network! We can check whether running the simulatio
150
142
``` @example friction
151
143
res_p = set_x(prob, res.u)
152
144
res_prob = remake(prob, p = res_p)
153
- res_sol = solve(res_prob, Rodas4 (), saveat = sol_ref.t)
145
+ res_sol = solve(res_prob, Vern7 (), saveat = sol_ref.t)
154
146
@test first.(sol_ref.u)≈first.(res_sol.u) rtol=1e-3 #hide
155
- @test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.output.u ], 1)) rtol=1e-1 #hide
147
+ @test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.outputs ], 1)) rtol=1e-1 #hide
156
148
nothing #hide
157
149
```
158
150
159
151
Also, it would be interesting to check the simulation before the training to get an idea of the starting point of the network.
160
152
161
153
``` @example friction
162
- initial_sol = solve(prob, Rodas4 (), saveat = sol_ref.t)
154
+ initial_sol = solve(prob, Vern7 (), saveat = sol_ref.t)
163
155
```
164
156
165
157
Now we plot it.
166
158
167
159
``` @example friction
168
160
scatter(sol_ref, idxs = [model_true.y], label = "ground truth velocity")
169
- plot!(res_sol, idxs = [sys.friction. y], label = "velocity after training")
170
- plot!(initial_sol, idxs = [sys.friction. y], label = "velocity before training")
161
+ plot!(res_sol, idxs = [sys.y], label = "velocity after training")
162
+ plot!(initial_sol, idxs = [sys.y], label = "velocity before training")
171
163
```
172
164
173
165
It matches the data well! Let's also check the predictions for the friction force and whether the network learnt the friction model or not.
174
166
175
167
``` @example friction
176
168
scatter(sol_ref.t, friction.(first.(sol_ref.u)), label = "ground truth friction")
177
- plot!(res_sol.t, getindex.(res_sol[sys.nn.output.u ], 1),
169
+ plot!(res_sol.t, getindex.(res_sol[sys.nn.outputs ], 1),
178
170
label = "friction from neural network")
179
171
```
180
172
0 commit comments