Skip to content

Commit fcaf378

Browse files
committed
docs: update docs
1 parent bca35ac commit fcaf378

File tree

3 files changed

+33
-41
lines changed

3 files changed

+33
-41
lines changed

docs/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1515
[compat]
1616
Documenter = "1.3"
1717
Lux = "1"
18-
ModelingToolkit = "9.9"
19-
ModelingToolkitNeuralNets = "1"
18+
ModelingToolkit = "10"
19+
ModelingToolkitNeuralNets = "2"
2020
ModelingToolkitStandardLibrary = "2.7"
2121
Optimization = "3.24, 4.0"
2222
OptimizationOptimisers = "0.2.1, 0.3"
@@ -27,4 +27,4 @@ StableRNGs = "1"
2727
SymbolicIndexingInterface = "0.3.15"
2828

2929
[sources]
30-
ModelingToolkitNeuralNets = { path = ".." }
30+
ModelingToolkitNeuralNets = {path = ".."}

docs/src/friction.md

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ function friction_true()
4949
eqs = [
5050
Dt(y) ~ Fu - friction(y)
5151
]
52-
return ODESystem(eqs, t, name = :friction_true)
52+
return System(eqs, t, name = :friction_true)
5353
end
5454
```
5555

5656
Now that we have defined the model, we will simulate it from 0 to 0.1 seconds.
5757

5858
```@example friction
59-
model_true = structural_simplify(friction_true())
60-
prob_true = ODEProblem(model_true, [], (0, 0.1), [])
61-
sol_ref = solve(prob_true, Rodas4(); saveat = 0.001)
59+
model_true = mtkcompile(friction_true())
60+
prob_true = ODEProblem(model_true, [], (0, 0.1))
61+
sol_ref = solve(prob_true, Vern7(); saveat = 0.001)
6262
```
6363

6464
Let's plot it.
@@ -81,28 +81,23 @@ Now, we will try to learn the same friction model using a neural network. We wil
8181
function friction_ude(Fu)
8282
@variables y(t) = 0.0
8383
@constants Fu = Fu
84-
@named nn_in = RealInputArray(nin = 1)
85-
@named nn_out = RealOutputArray(nout = 1)
86-
eqs = [Dt(y) ~ Fu - nn_in.u[1]
87-
y ~ nn_out.u[1]]
88-
return ODESystem(eqs, t, name = :friction, systems = [nn_in, nn_out])
89-
end
9084
91-
Fu = 120.0
92-
model = friction_ude(Fu)
85+
chain = Lux.Chain(
86+
Lux.Dense(1 => 10, Lux.mish, use_bias = false),
87+
Lux.Dense(10 => 10, Lux.mish, use_bias = false),
88+
Lux.Dense(10 => 1, use_bias = false)
89+
)
90+
@named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
9391
94-
chain = Lux.Chain(
95-
Lux.Dense(1 => 10, Lux.mish, use_bias = false),
96-
Lux.Dense(10 => 10, Lux.mish, use_bias = false),
97-
Lux.Dense(10 => 1, use_bias = false)
98-
)
99-
@named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
92+
eqs = [Dt(y) ~ Fu - nn.outputs[1]
93+
y ~ nn.inputs[1]]
94+
return System(eqs, t, name = :friction, systems = [nn])
95+
end
10096
101-
eqs = [connect(model.nn_in, nn.output)
102-
connect(model.nn_out, nn.input)]
97+
Fu = 120.0
10398
104-
ude_sys = complete(ODESystem(eqs, t, systems = [model, nn], name = :ude_sys))
105-
sys = structural_simplify(ude_sys)
99+
ude_sys = friction_ude(Fu)
100+
sys = mtkcompile(ude_sys)
106101
```
107102

108103
## Optimization Setup
@@ -114,22 +109,19 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
114109
new_p = set_x(prob, x)
115110
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
116111
ts = sol_ref.t
117-
new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
118-
loss = zero(eltype(x))
119-
for i in eachindex(new_sol.u)
120-
loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
121-
end
112+
new_sol = solve(new_prob, Vern7(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
113+
122114
if SciMLBase.successful_retcode(new_sol)
123-
loss
115+
mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
124116
else
125117
Inf
126118
end
127119
end
128120
129-
of = OptimizationFunction{true}(loss, AutoForwardDiff())
121+
of = OptimizationFunction(loss, AutoForwardDiff())
130122
131-
prob = ODEProblem(sys, [], (0, 0.1), [])
132-
get_vars = getu(sys, [sys.friction.y])
123+
prob = ODEProblem(sys, [], (0, 0.1))
124+
get_vars = getu(sys, [sys.y])
133125
get_refs = getu(model_true, [model_true.y])
134126
set_x = setp_oop(sys, sys.nn.p)
135127
x0 = default_values(sys)[sys.nn.p]
@@ -150,31 +142,31 @@ We now have a trained neural network! We can check whether running the simulatio
150142
```@example friction
151143
res_p = set_x(prob, res.u)
152144
res_prob = remake(prob, p = res_p)
153-
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
145+
res_sol = solve(res_prob, Vern7(), saveat = sol_ref.t)
154146
@test first.(sol_ref.u)≈first.(res_sol.u) rtol=1e-3 #hide
155-
@test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.output.u], 1)) rtol=1e-1 #hide
147+
@test friction.(first.(sol_ref.u))≈(getindex.(res_sol[sys.nn.outputs], 1)) rtol=1e-1 #hide
156148
nothing #hide
157149
```
158150

159151
Also, it would be interesting to check the simulation before the training to get an idea of the starting point of the network.
160152

161153
```@example friction
162-
initial_sol = solve(prob, Rodas4(), saveat = sol_ref.t)
154+
initial_sol = solve(prob, Vern7(), saveat = sol_ref.t)
163155
```
164156

165157
Now we plot it.
166158

167159
```@example friction
168160
scatter(sol_ref, idxs = [model_true.y], label = "ground truth velocity")
169-
plot!(res_sol, idxs = [sys.friction.y], label = "velocity after training")
170-
plot!(initial_sol, idxs = [sys.friction.y], label = "velocity before training")
161+
plot!(res_sol, idxs = [sys.y], label = "velocity after training")
162+
plot!(initial_sol, idxs = [sys.y], label = "velocity before training")
171163
```
172164

173165
It matches the data well! Let's also check the predictions for the friction force and whether the network learnt the friction model or not.
174166

175167
```@example friction
176168
scatter(sol_ref.t, friction.(first.(sol_ref.u)), label = "ground truth friction")
177-
plot!(res_sol.t, getindex.(res_sol[sys.nn.output.u], 1),
169+
plot!(res_sol.t, getindex.(res_sol[sys.nn.outputs], 1),
178170
label = "friction from neural network")
179171
```
180172

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Pkg.add("ModelingToolkitNeuralNets")
2323

2424
- See the [SciML Style Guide](https://github.com/SciML/SciMLStyle) for common coding practices and other style decisions.
2525
- There are a few community forums:
26-
26+
2727
+ The #diffeq-bridged and #sciml-bridged channels in the
2828
[Julia Slack](https://julialang.org/slack/)
2929
+ The #diffeq-bridged and #sciml-bridged channels in the

0 commit comments

Comments
 (0)