Commit 210ee41

Merge pull request #249 from SciML/docsoptv4
Update docs remove extra returns from loss and extra args from callback
2 parents 0be2ba4 + 109fdac commit 210ee41

File tree

9 files changed: +31 −36 lines changed

docs/Project.toml

Lines changed: 7 additions & 7 deletions
```diff
@@ -69,7 +69,7 @@ IncompleteLU = "0.2"
 Integrals = "4"
 LineSearches = "7"
 LinearSolve = "2"
-Lux = "0.5"
+Lux = "1"
 LuxCUDA = "0.3"
 MCMCChains = "6"
 Measurements = "2"
@@ -78,12 +78,12 @@ ModelingToolkit = "9.9"
 MultiDocumenter = "0.7"
 NeuralPDE = "5.15"
 NonlinearSolve = "3"
-Optimization = "3"
-OptimizationMOI = "0.4"
-OptimizationNLopt = "0.2"
-OptimizationOptimJL = "0.2, 0.3"
-OptimizationOptimisers = "0.2"
-OptimizationPolyalgorithms = "0.2"
+Optimization = "4"
+OptimizationMOI = "0.5"
+OptimizationNLopt = "0.3"
+OptimizationOptimJL = "0.4"
+OptimizationOptimisers = "0.3"
+OptimizationPolyalgorithms = "0.3"
 OrdinaryDiffEq = "6"
 Plots = "1"
 SciMLExpectations = "2"
```

docs/make.jl

Lines changed: 2 additions & 1 deletion
```diff
@@ -25,7 +25,8 @@ makedocs(sitename = "Overview of Julia's SciML",
         "https://epubs.siam.org/doi/10.1137/0903023",
         "https://bkamins.github.io/julialang/2020/12/24/minilanguage.html",
         "https://arxiv.org/abs/2109.06786",
-        "https://arxiv.org/abs/2001.04385"],
+        "https://arxiv.org/abs/2001.04385",
+        "https://code.visualstudio.com/"],
     format = Documenter.HTML(assets = ["assets/favicon.ico"],
         canonical = "https://docs.sciml.ai/stable/",
         mathengine = mathengine),
```

docs/src/getting_started/fit_simulation.md

Lines changed: 13 additions & 18 deletions
````diff
@@ -99,12 +99,14 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     loss = sum(abs2, sol .- xy_data)
-    return loss, sol
+    return loss
 end
 
 # Define a callback function to monitor optimization progress
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
@@ -278,37 +280,28 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     l = sum(abs2, sol .- xy_data)
-    return l, sol
+    return l
 end
 ```
 
-Notice that our loss function returns the loss value as the first return,
-but returns extra information (the ODE solution with the new parameters)
-as an extra return argument.
-We will explain why this extra return information is helpful in the next section.
-
 ### Step 5: Solve the Optimization Problem
 
 This step will look very similar to [the first optimization tutorial](@ref first_opt),
-except now we have a new loss function `loss` which returns both the loss value
-and the associated ODE solution.
-(In the previous tutorial, `L` only returned the loss value.)
 The `Optimization.solve` function can accept an optional callback function
 to monitor the optimization process using extra arguments returned from `loss`.
 
 The callback syntax is always:
 
 ```
 callback(
-    optimization variables,
+    state,
     the current loss value,
-    other arguments returned from the loss function, ...
 )
 ```
 
-In this case, we will provide the callback the arguments `(p, l, sol)`,
-since it always takes the current state of the optimization first (`p`)
-then the returns from the loss function (`l, sol`).
+In this case, we will provide the callback the arguments `(state, l)`,
+since it always takes the current state of the optimization first (`state`)
+then the current loss value (`l`).
 The return value of the callback function should default to `false`.
 `Optimization.solve` will halt if/when the callback function returns `true` instead.
 Typically the `return` statement would monitor the loss value
@@ -318,16 +311,18 @@ More details about callbacks in Optimization.jl can be found
 [here](https://docs.sciml.ai/Optimization/stable/API/solve/).
 
 ```@example odefit
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
     return false
 end
 ```
 
-With this callback function, every step of the optimization will display both the loss value and a plot of how the solution compares to the training data.
+With this callback function, every step of the optimization will display both the loss value and a plot of how the solution compares to the training data. Since we want to track the fit visually we plot the simulation at each iteration and compare it to the data. This is expensive since it requires an extra `solve` call and a plotting step for each iteration.
 
 Now, just like [the first optimization tutorial](@ref first_opt),
 we set up our `OptimizationFunction` and `OptimizationProblem`,
````
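The behavioral core of this commit is the new two-argument callback contract from the Optimization.jl v4 release: `callback(state, l)`, where `state.u` carries the current optimization variables, and a `true` return value halts the solve. A minimal, dependency-free sketch of that contract follows; the hand-rolled gradient-descent loop, `loss`, and `history` below are illustrative stand-ins, not the tutorial's actual code or the Optimization.jl internals.

```julia
loss(u) = sum(abs2, u)   # simple quadratic loss, minimum at zero
grad(u) = 2 .* u         # its analytic gradient

# The v4-style callback takes (state, l); returning `false` continues
# the optimization, returning `true` halts it early.
history = Float64[]
function callback(state, l)
    push!(history, l)
    return l < 1e-6      # halt once the loss is small enough
end

# A toy optimizer loop that honors the callback contract: it builds a
# state object (here a NamedTuple mimicking the solver state, with the
# current variables in `state.u`) and stops when the callback says so.
function descend(u0; lr = 0.1, iters = 100)
    u = copy(u0)
    for _ in 1:iters
        l = loss(u)
        state = (u = u, iter = length(history))
        callback(state, l) && break
        u .-= lr .* grad(u)
    end
    return u
end

u = descend([2.0, -1.5])
```

With Optimization.jl itself, the same callback would be passed as `solve(optprob, method; callback = callback)`; the only difference is that the solver constructs `state` for you.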

docs/src/highlevels/modeling_languages.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -50,7 +50,7 @@ doing standard molecular dynamics approximations.
 
 ## DiffEqFinancial.jl: Financial models for use in the DifferentialEquations ecosystem
 
-The goal of [DiffEqFinancial.jl](https://github.com/SciML/DiffEqFinancial.jl/commits/master) is to be a feature-complete set
+The goal of [DiffEqFinancial.jl](https://github.com/SciML/DiffEqFinancial.jl/) is to be a feature-complete set
 of solvers for the types of problems found in libraries like QuantLib, such as the Heston process or the
 Black-Scholes model.
```
docs/src/showcase/blackhole.md

Lines changed: 4 additions & 5 deletions
````diff
@@ -490,10 +490,8 @@ function loss(NN_params)
         prob_nn, RK4(), u0 = u0, p = NN_params, saveat = tsteps, dt = dt, adaptive = false))
     pred_waveform = compute_waveform(dt_data, pred, mass_ratio, model_params)[1]
 
-    loss = (sum(abs2,
-        view(waveform, obs_to_use_for_training) .-
-        view(pred_waveform, obs_to_use_for_training)))
-    return loss, pred_waveform
+    loss = ( sum(abs2, view(waveform,obs_to_use_for_training) .- view(pred_waveform,obs_to_use_for_training) ) )
+    return loss
 end
 ```
@@ -508,10 +506,11 @@ We'll use the following callback to save the history of the loss values.
 
 ```@example ude
 losses = []
 
-callback(θ, l, pred_waveform; doplot = true) = begin
+callback(state, l; doplot = true) = begin
     push!(losses, l)
     #= Disable plotting as it trains since in docs
     display(l)
+    waveform = compute_waveform(dt_data, soln, mass_ratio, model_params)[1]
     # plot current prediction against data
     plt = plot(tsteps, waveform,
         markershape=:circle, markeralpha = 0.25,
````

docs/src/showcase/gpu_spde.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -302,7 +302,7 @@ These last two ways enclose the pointer to our cache arrays locally but still pr
 function f(du,u,p,t) to the ODE solver.
 
 Now, since PDEs are large, many times we don't care about getting the whole timeseries. Using
-the [output controls from DifferentialEquations.jl](https://diffeq.sciml.ai/latest/basics/common_solver_opts.html#Output-Control-1), we can make it only output the final timepoint.
+the [output controls from DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/), we can make it only output the final timepoint.
 
 ```julia
 prob = ODEProblem(f, u0, (0.0, 100.0))
````
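The output controls mentioned in that hunk can be sketched with a scalar stand-in problem; the decay right-hand side below is illustrative, and only the `save_everystep`/`save_start` keywords (standard DifferentialEquations.jl common solver options) are the point.

```julia
using OrdinaryDiffEq

f(u, p, t) = -u  # scalar decay, a stand-in for the large PDE right-hand side
prob = ODEProblem(f, 1.0, (0.0, 100.0))

# save_everystep = false discards intermediate steps and save_start = false
# discards the initial condition, so only the final timepoint is stored.
sol = solve(prob, Tsit5(), save_everystep = false, save_start = false)
```

For a PDE whose state is a large array, this keeps memory usage at a single snapshot instead of the full timeseries.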

docs/src/showcase/missing_physics.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -222,7 +222,7 @@ current loss:
 ```@example ude
 losses = Float64[]
 
-callback = function (p, l)
+callback = function (state, l)
     push!(losses, l)
     if length(losses) % 50 == 0
         println("Current loss after $(length(losses)) iterations: $(losses[end])")
````

docs/src/showcase/pinngpu.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -148,7 +148,7 @@ prob = discretize(pde_system, discretization)
 ## Step 6: Solve the Optimization Problem
 
 ```@example pinn
-callback = function (p, l)
+callback = function (state, l)
     println("Current loss is: $l")
     return false
 end
````

docs/src/showcase/symbolic_analysis.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -118,7 +118,7 @@ Did you implement the DAE incorrectly? No. Is the solver broken? No.
 
 It turns out that this is a property of the DAE that we are attempting to solve.
 This kind of DAE is known as an index-3 DAE. For a complete discussion of DAE
-index, see [this article](https://www.scholarpedia.org/article/Differential-algebraic_equations).
+index, see [this article](http://www.scholarpedia.org/article/Differential-algebraic_equations).
 Essentially, the issue here is that we have 4 differential variables (``x``, ``v_x``, ``y``, ``v_y``)
 and one algebraic variable ``T`` (which we can know because there is no `D(T)` 
 term in the equations). An index-1 DAE always satisfies that the Jacobian of
```
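The index-3 claim in that hunk can be checked by hand: the DAE index is the number of differentiations of the constraint needed before the algebraic variable appears under a derivative. The sketch below assumes the pendulum convention from that showcase page, ``D(v_x) = T x`` and ``D(v_y) = T y - g`` with constraint ``x^2 + y^2 = L^2``; if the page uses a different sign convention for ``T``, the signs below change accordingly.

```math
\begin{aligned}
0 &= \frac{d}{dt}\left(x^2 + y^2 - L^2\right) = 2\,(x v_x + y v_y), \\
0 &= \frac{d}{dt}\left(x v_x + y v_y\right) = v_x^2 + v_y^2 + T\,(x^2 + y^2) - g\,y
   = v_x^2 + v_y^2 + T L^2 - g\,y.
\end{aligned}
```

The second differentiation finally produces an algebraic equation for the tension, ``T = (g\,y - v_x^2 - v_y^2)/L^2``; a third differentiation would be needed to obtain a differential equation for ``T``, which is exactly what makes this an index-3 DAE.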

0 commit comments
