Skip to content

Commit 2f855d3

Browse files
docs: update doc example
1 parent 4f6cf44 commit 2f855d3

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

docs/src/friction.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ sys = structural_simplify(ude_sys)
110110
We now setup the loss function and the optimization loop.
111111

112112
```@example friction
113-
function loss(x, (prob, sol_ref, get_vars, get_refs))
114-
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
113+
function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
114+
new_p = set_x(prob, x)
115115
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
116116
ts = sol_ref.t
117117
new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
@@ -131,14 +131,15 @@ of = OptimizationFunction{true}(loss, AutoForwardDiff())
131131
prob = ODEProblem(sys, [], (0, 0.1), [])
132132
get_vars = getu(sys, [sys.friction.y])
133133
get_refs = getu(model_true, [model_true.y])
134-
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))
134+
set_x = setp_oop(sys, sys.nn.p)
135+
x0 = default_values(sys)[sys.nn.p]
135136
136137
cb = (opt_state, loss) -> begin
137138
@info "step $(opt_state.iter), loss: $loss"
138139
return false
139140
end
140141
141-
op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
142+
op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs, set_x))
142143
res = solve(op, Adam(5e-3); maxiters = 10000, callback = cb)
143144
```
144145

@@ -147,7 +148,7 @@ res = solve(op, Adam(5e-3); maxiters = 10000, callback = cb)
147148
We now have a trained neural network! We can check whether running the simulation of the model embedded with the neural network matches the data or not.
148149

149150
```@example friction
150-
res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)
151+
res_p = set_x(prob, res.u)
151152
res_prob = remake(prob, p = res_p)
152153
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
153154
@test first.(sol_ref.u)≈first.(res_sol.u) rtol=1e-3 #hide

0 commit comments

Comments
 (0)