
Commit 23c2a3d

Vaibhavdixit02 authored and ChrisRackauckas committed
Update docs: remove extra returns from loss and extra args from callback
1 parent 0be2ba4 commit 23c2a3d

4 files changed: +18 additions, -24 deletions
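In short: the loss functions now return only a scalar, and callbacks take `(state, l)`, with the current parameters available as `state.u`. A minimal runnable sketch of the new convention (the Rosenbrock objective and the `BFGS` choice are stand-ins, not taken from the changed files):

```julia
using Optimization, OptimizationOptimJL

# Stand-in objective; under the new convention, only the scalar loss is returned.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

# New-style callback: the optimization state first, then the current loss value.
callback = function (state, l)
    println("loss = $l at u = $(state.u)")
    return false  # returning true would halt `Optimization.solve`
end

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
sol = solve(optprob, BFGS(); callback = callback)
```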

docs/src/getting_started/fit_simulation.md

Lines changed: 12 additions & 17 deletions
@@ -99,12 +99,14 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     loss = sum(abs2, sol .- xy_data)
-    return loss, sol
+    return loss
 end

 # Define a callback function to monitor optimization progress
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
@@ -278,37 +280,28 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     l = sum(abs2, sol .- xy_data)
-    return l, sol
+    return l
 end
 ```

-Notice that our loss function returns the loss value as the first return,
-but returns extra information (the ODE solution with the new parameters)
-as an extra return argument.
-We will explain why this extra return information is helpful in the next section.
-
 ### Step 5: Solve the Optimization Problem

 This step will look very similar to [the first optimization tutorial](@ref first_opt),
-except now we have a new loss function `loss` which returns both the loss value
-and the associated ODE solution.
-(In the previous tutorial, `L` only returned the loss value.)
 The `Optimization.solve` function can accept an optional callback function
 to monitor the optimization process using extra arguments returned from `loss`.

 The callback syntax is always:

 ```
 callback(
-    optimization variables,
+    state,
     the current loss value,
-    other arguments returned from the loss function, ...
 )
 ```

-In this case, we will provide the callback the arguments `(p, l, sol)`,
-since it always takes the current state of the optimization first (`p`)
-then the returns from the loss function (`l, sol`).
+In this case, we will provide the callback the arguments `(state, l)`,
+since it always takes the current state of the optimization first (`state`)
+then the current loss value (`l`).
 The return value of the callback function should default to `false`.
 `Optimization.solve` will halt if/when the callback function returns `true` instead.
 Typically the `return` statement would monitor the loss value
@@ -318,8 +311,10 @@ More details about callbacks in Optimization.jl can be found
 [here](https://docs.sciml.ai/Optimization/stable/API/solve/).

 ```@example odefit
-function callback(p, l, sol)
+function callback(p, l)
     display(l)
+    newprob = remake(prob, p = p)
+    sol = solve(newprob, saveat = 1)
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
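Since `loss` now returns only the scalar value, the callback above reconstructs the ODE solution from the current parameters before plotting. A hedged sketch of how these pieces would be wired into `Optimization.solve` (the adapter closure, `pguess`, and the `NelderMead` choice are assumptions, not taken from this diff):

```julia
using Optimization, OptimizationOptimJL

# Adapt the tutorial's one-argument `loss(newp)` to the `(u, p)` signature
# that `OptimizationFunction` expects; `pguess` is an assumed initial guess.
optf = OptimizationFunction((u, _) -> loss(u), Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, pguess)
pfinal = solve(optprob, NelderMead(); callback = callback)
```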

docs/src/showcase/blackhole.md

Lines changed: 4 additions & 5 deletions
@@ -490,10 +490,8 @@ function loss(NN_params)
         prob_nn, RK4(), u0 = u0, p = NN_params, saveat = tsteps, dt = dt, adaptive = false))
     pred_waveform = compute_waveform(dt_data, pred, mass_ratio, model_params)[1]

-    loss = (sum(abs2,
-        view(waveform, obs_to_use_for_training) .-
-        view(pred_waveform, obs_to_use_for_training)))
-    return loss, pred_waveform
+    loss = ( sum(abs2, view(waveform,obs_to_use_for_training) .- view(pred_waveform,obs_to_use_for_training) ) )
+    return loss
 end
 ```

@@ -508,10 +506,11 @@ We'll use the following callback to save the history of the loss values.
 ```@example ude
 losses = []

-callback(θ, l, pred_waveform; doplot = true) = begin
+callback(state, l; doplot = true) = begin
     push!(losses, l)
     #= Disable plotting as it trains since in docs
     display(l)
+    waveform = compute_waveform(dt_data, soln, mass_ratio, model_params)[1]
     # plot current prediction against data
     plt = plot(tsteps, waveform,
         markershape=:circle, markeralpha = 0.25,
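Since `loss` no longer returns `pred_waveform`, a callback that plots needs to recompute the prediction itself. A hedged sketch of how that could be done from the optimization state, mirroring the tutorial's own `solve` call inside `loss` (`prob_nn`, `u0`, `tsteps`, `dt`, `dt_data`, `mass_ratio`, `model_params`, and `compute_waveform` are assumed from earlier in the tutorial):

```julia
# Recompute the predicted waveform from the current parameters `state.u`,
# using the same integrator settings as the tutorial's loss function:
pred = Array(solve(prob_nn, RK4(), u0 = u0, p = state.u,
    saveat = tsteps, dt = dt, adaptive = false))
pred_waveform = compute_waveform(dt_data, pred, mass_ratio, model_params)[1]
```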

docs/src/showcase/missing_physics.md

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ current loss:
 ```@example ude
 losses = Float64[]

-callback = function (p, l)
+callback = function (state, l)
     push!(losses, l)
     if length(losses) % 50 == 0
         println("Current loss after $(length(losses)) iterations: $(losses[end])")

docs/src/showcase/pinngpu.md

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ prob = discretize(pde_system, discretization)
 ## Step 6: Solve the Optimization Problem

 ```@example pinn
-callback = function (p, l)
+callback = function (state, l)
     println("Current loss is: $l")
     return false
 end
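Because `Optimization.solve` halts when the callback returns `true`, the same `(state, l)` signature also supports early stopping. A small variant sketch (the `1e-4` threshold is an arbitrary, assumed value):

```julia
callback = function (state, l)
    println("Current loss is: $l")
    return l < 1e-4  # returning true halts `Optimization.solve` early
end
```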
