
Commit c602d13

Merge pull request #1128 from SciML/docsoptv4
Update docs: remove extra returns from loss and extra args from callback
2 parents 417e46d + 5d458ad commit c602d13

17 files changed: +108 -60 lines changed

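Every documentation change below applies the same pattern for the updated Optimization.jl interface (the `docsoptv4` branch name suggests v4): loss functions now return only the scalar loss, and callbacks take `(state, l)` and recompute anything else, such as predictions, from `state.u`. A minimal sketch of that pattern; `predict`, `data`, and `θ0` are placeholders rather than names from any one example:

```julia
using Optimization, OptimizationOptimisers, Zygote

# Loss returns only the scalar; extra return values are no longer forwarded
# to the callback.
function loss(p)
    pred = predict(p)          # `predict` is a placeholder for the model solve
    return sum(abs2, data .- pred)
end

# Callback receives the optimization state and the loss; the current
# parameters are available as state.u, so predictions are recomputed here.
callback = function (state, l)
    pred = predict(state.u)
    println(l)
    return false               # return true to halt the optimization
end

optf = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, θ0)
result = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.05);
    callback = callback, maxiters = 300)
```

The per-example diffs below apply exactly this shape: drop the extra return from the loss, drop the extra callback argument, and rebuild the prediction inside the callback.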

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ PreallocationTools = "0.4.4"
 QuadGK = "2.9.1"
 Random = "1.10"
 RandomNumbers = "1.5.3"
-RecursiveArrayTools = "3.18.1"
+RecursiveArrayTools = "3.27.2"
 Reexport = "1.0"
 ReverseDiff = "1.15.1"
 SafeTestsets = "0.1.0"

docs/Project.toml

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
@@ -23,6 +24,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -45,6 +47,7 @@ Enzyme = "0.12, 0.13"
 Flux = "0.14"
 ForwardDiff = "0.10"
 IterTools = "1"
+MLUtils = "0.4"
 Lux = "1"
 LuxCUDA = "0.3"
 Optimization = "3.9, 4"
@@ -56,6 +59,7 @@ Plots = "1.36"
 QuadGK = "2.6"
 RecursiveArrayTools = "2.32, 3"
 ReverseDiff = "1.14"
+SciMLBase = "2.58"
 SciMLSensitivity = "7.11"
 SimpleChains = "0.4"
 StaticArrays = "1"

docs/src/examples/dde/delay_diffeq.md

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ end
 loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p))

 using Plots
-callback = function (state, l...; doplot = false)
+callback = function (state, l; doplot = false)
     display(loss_dde(state.u))
     doplot &&
         display(plot(
@@ -60,7 +60,7 @@ We define a callback to display the solution at the current parameters for each

 ```@example dde
 using Plots
-callback = function (state, l...; doplot = false)
+callback = function (state, l; doplot = false)
     display(loss_dde(state.u))
     doplot &&
         display(plot(

docs/src/examples/neural_ode/minibatch.md

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,7 @@

 ```@example
 using SciMLSensitivity
-using DifferentialEquations, Flux, Random, Plots
+using DifferentialEquations, Flux, Random, Plots, MLUtils
 using IterTools: ncycle

 rng = Random.default_rng()
@@ -46,7 +46,7 @@ ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
 prob = ODEProblem{false}(dudt_, u0, tspan, θ)

 k = 10
-train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
+train_loader = DataLoader((ode_data, t), batchsize = k)

 for (x, y) in train_loader
     @show x
@@ -96,7 +96,7 @@ When training a neural network, we need to find the gradient with respect to our
 For this example, we will use a very simple ordinary differential equation, newtons law of cooling. We can represent this in Julia like so.

 ```@example minibatch
-using SciMLSensitivity
+using SciMLSensitivity, MLUtils
 using DifferentialEquations, Flux, Random, Plots
 using IterTools: ncycle
@@ -152,7 +152,7 @@ ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
 prob = ODEProblem{false}(dudt_, u0, tspan, θ)

 k = 10
-train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
+train_loader = DataLoader((ode_data, t), batchsize = k)

 for (x, y) in train_loader
     @show x
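For context on the loader swap above: recent Flux releases no longer ship `Flux.Data.DataLoader`, so the example now pulls `DataLoader` from MLUtils.jl, which is also why MLUtils was added to the docs dependencies. A minimal sketch of the replacement, assuming `ode_data` and `t` are the data matrix and time vector built earlier in that example:

```julia
using MLUtils

# DataLoader batches along the last dimension: each iteration yields k
# columns of ode_data together with the matching k entries of t.
k = 10
train_loader = DataLoader((ode_data, t), batchsize = k)

for (x, y) in train_loader
    @show size(x), size(y)
end
```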

docs/src/examples/neural_ode/neural_ode_flux.md

Lines changed: 5 additions & 4 deletions
@@ -114,22 +114,23 @@ end
 function loss_n_ode(θ)
     pred = predict_n_ode(θ)
     loss = sum(abs2, ode_data .- pred)
-    loss, pred
+    return loss
 end

 loss_n_ode(θ)

-callback = function (θ, l, pred; doplot = false) #callback function to observe training
+callback = function (state, l; doplot = false) #callback function to observe training
     display(l)
     # plot current prediction against data
+    pred = predict_n_ode(state.u)
     pl = scatter(t, ode_data[1, :], label = "data")
     scatter!(pl, t, pred[1, :], label = "prediction")
     display(plot(pl))
     return false
 end

 # Display the ODE with the initial parameter values.
-callback(θ, loss_n_ode(θ)...)
+callback((; u = θ), loss_n_ode(θ)...)

 # use Optimization.jl to solve the problem
 adtype = Optimization.AutoZygote()
@@ -143,7 +144,7 @@ result_neuralode = Optimization.solve(optprob,
     maxiters = 300)
 ```

-Notice that the advantage of this format is that we can use Optim's optimizers, like
+Notice that the advantage of this format is that we can use other optimizers, like
 `LBFGS` with a full `Chain` object, for all of Flux's neural networks, like
 convolutional neural networks.

docs/src/examples/neural_ode/simplechains.md

Lines changed: 4 additions & 3 deletions
@@ -54,19 +54,20 @@ end
 function loss_neuralode(p)
     pred = predict_neuralode(p)
     loss = sum(abs2, data .- pred)
-    return loss, pred
+    return loss
 end
 ```

 ## Training

 The next step is to minimize the loss, so that the NeuralODE gets trained. But in order to be able to do that, we have to be able to backpropagate through the NeuralODE model. Here the backpropagation through the neural network is the easy part, and we get that out of the box with any deep learning package(although not as fast as SimpleChains for the small nn case here). But we have to find a way to first propagate the sensitivities of the loss back, first through the ODE solver and then to the neural network.

-The adjoint of a neural ODE can be calculated through the various AD algorithms available in SciMLSensitivity.jl. But working with [StaticArrays](https://docs.sciml.ai/StaticArrays/stable/) in SimpleChains.jl requires a special adjoint method as StaticArrays do not allow any mutation. All the adjoint methods make heavy use of in-place mutation to be performant with the heap allocated normal arrays. For our statically sized, stack allocated StaticArrays, in order to be able to compute the ODE adjoint we need to do everything out of place. Hence, we have specifically used `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the solve call inside `predict_neuralode(p)` which computes everything out-of-place when u0 is a StaticArray. Hence, we can move forward with the training of the NeuralODE
+The adjoint of a neural ODE can be calculated through the various AD algorithms available in SciMLSensitivity.jl. But working with [StaticArrays](https://juliaarrays.github.io/StaticArrays.jl/stable/) in SimpleChains.jl requires a special adjoint method as StaticArrays do not allow any mutation. All the adjoint methods make heavy use of in-place mutation to be performant with the heap allocated normal arrays. For our statically sized, stack allocated StaticArrays, in order to be able to compute the ODE adjoint we need to do everything out of place. Hence, we have specifically used `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the solve call inside `predict_neuralode(p)` which computes everything out-of-place when u0 is a StaticArray. Hence, we can move forward with the training of the NeuralODE

 ```@example sc_neuralode
-callback = function (state, l, pred; doplot = true)
+callback = function (state, l; doplot = true)
     display(l)
+    pred = predict_neuralode(state.u)
     plt = scatter(tsteps, data[1, :], label = "data")
     scatter!(plt, tsteps, pred[1, :], label = "prediction")
     if doplot
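The (otherwise unchanged) paragraph above motivates the out-of-place adjoint choice; as a reminder of what it refers to, here is a hedged sketch of the kind of `solve` call used inside `predict_neuralode`, assuming `prob` and `tsteps` are set up as in that example:

```julia
using OrdinaryDiffEq, SciMLSensitivity

# QuadratureAdjoint with a Zygote vector-Jacobian product keeps the adjoint
# computation out-of-place, which is required because StaticArrays cannot be
# mutated in place.
function predict_neuralode(p)
    Array(solve(prob, Tsit5(); p = p, saveat = tsteps,
        sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())))
end
```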

docs/src/examples/ode/second_order_adjoints.md

Lines changed: 3 additions & 2 deletions
@@ -49,13 +49,13 @@ end
 function loss_neuralode(p)
     pred = predict_neuralode(p)
     loss = sum(abs2, ode_data .- pred)
-    return loss, pred
+    return loss
 end

 # Callback function to observe training
 list_plots = []
 iter = 0
-callback = function (state, l, pred; doplot = false)
+callback = function (state, l; doplot = false)
     global list_plots, iter

     if iter == 0
@@ -66,6 +66,7 @@ callback = function (state, l, pred; doplot = false)
     display(l)

     # plot current prediction against data
+    pred = predict_neuralode(state.u)
     plt = scatter(tsteps, ode_data[1, :], label = "data")
     scatter!(plt, tsteps, pred[1, :], label = "prediction")
     push!(list_plots, plt)

docs/src/examples/ode/second_order_neural.md

Lines changed: 3 additions & 3 deletions
@@ -33,7 +33,7 @@ t = range(tspan[1], tspan[2], length = 20)
 model = Chain(Dense(2, 50, tanh), Dense(50, 2))
 ps, st = Lux.setup(Random.default_rng(), model)
 ps = ComponentArray(ps)
-model = StatefulLuxLayer{true}(model, ps, st)
+model = Lux.StatefulLuxLayer{true}(model, ps, st)

 ff(du, u, p, t) = model(u, p)
 prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)
@@ -46,12 +46,12 @@ correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:

 function loss_n_ode(p)
     pred = predict(p)
-    sum(abs2, correct_pos .- pred[1:2, :]), pred
+    sum(abs2, correct_pos .- pred[1:2, :])
 end

 l1 = loss_n_ode(ps)

-callback = function (state, l, pred)
+callback = function (state, l)
     println(l)
     l < 0.01
 end

docs/src/examples/optimal_control/feedback_control.md

Lines changed: 10 additions & 4 deletions
@@ -61,7 +61,7 @@ l = loss_univ(θ)
 ```@example udeneuralcontrol
 list_plots = []
 iter = 0
-cb = function (state, l)
+cb = function (state, l; makeplot = false)
     global list_plots, iter

     if iter == 0
@@ -71,9 +71,11 @@ cb = function (state, l)

     println(l)

-    plt = plot(predict_univ(state.u)', ylim = (0, 6))
-    push!(list_plots, plt)
-    display(plt)
+    if makeplot
+        plt = plot(predict_univ(state.u)', ylim = (0, 6))
+        push!(list_plots, plt)
+        display(plt)
+    end
     return false
 end
 ```
@@ -84,3 +86,7 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_univ(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, θ)
 result_univ = Optimization.solve(optprob, PolyOpt(), callback = cb)
 ```
+
+```@example udeneuralcontrol
+cb(result_univ, result_univ.minimum; makeplot=true)
+```

docs/src/examples/optimal_control/optimal_control.md

Lines changed: 3 additions & 2 deletions
@@ -37,7 +37,8 @@ of a local minimum. This looks like:

 ```@example neuraloptimalcontrol
 using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
-    OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
+    OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random,
+    ForwardDiff

 rng = Random.default_rng()
 tspan = (0.0f0, 8.0f0)
@@ -89,7 +90,7 @@ end
 # Setup and run the optimization

 loss1 = loss_adjoint(θ)
-adtype = Optimization.AutoZygote()
+adtype = Optimization.AutoForwardDiff()
 optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

 optprob = Optimization.OptimizationProblem(optf, θ)
