
Commit f29d5d8

update minibatch and sophia docs
1 parent 1947896 commit f29d5d8


5 files changed: +75 -39 lines changed


docs/src/optimization_packages/optimisers.md

Lines changed: 2 additions & 21 deletions
@@ -12,27 +12,7 @@ Pkg.add("OptimizationOptimisers");
 In addition to the optimisation algorithms provided by the Optimisers.jl package this subpackage
 also provides the Sophia optimisation algorithm.
 
-## Local Unconstrained Optimizers
-
-  - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information
-    in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD.
-
-      + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`
-
-      + `η` is the learning rate
-      + `βs` are the decay of momentums
-      + `ϵ` is the epsilon value
-      + `λ` is the weight decay parameter
-      + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix
-      + `ρ` is the momentum
-      + Defaults:
-
-        * `η = 0.001`
-        * `βs = (0.9, 0.999)`
-        * `ϵ = 1e-8`
-        * `λ = 0.1`
-        * `k = 10`
-        * `ρ = 0.04`
+## List of optimizers
 
   - [`Optimisers.Descent`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Descent): **Classic gradient descent optimizer with learning rate**
 
@@ -42,6 +22,7 @@ also provides the Sophia optimisation algorithm.
       + Defaults:
 
         * `η = 0.1`
+
   - [`Optimisers.Momentum`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Momentum): **Classic gradient descent optimizer with learning rate and momentum**
 
       + `solve(problem, Momentum(η, ρ))`
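Note: the `solve(problem, Momentum(η, ρ))` pattern documented above applies to the other Optimisers.jl rules as well. The following is a minimal sketch of how the Descent and Momentum entries might be exercised end to end; it is not part of this commit, and the Rosenbrock objective and `maxiters` value are illustrative assumptions.

```julia
using Optimization, OptimizationOptimisers, Zygote

# Toy objective; any differentiable OptimizationFunction works here.
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# Descent with its documented default learning rate η = 0.1
sol_descent = solve(prob, Optimisers.Descent(), maxiters = 1000)

# Momentum with an explicit learning rate and momentum, mirroring solve(problem, Momentum(η, ρ))
sol_momentum = solve(prob, Optimisers.Momentum(0.01, 0.9), maxiters = 1000)
```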

docs/src/optimization_packages/optimization.md

Lines changed: 54 additions & 1 deletion
@@ -4,15 +4,35 @@ There are some solvers that are available in the Optimization.jl package directly
 
 ## Methods
 
-`LBFGS`: The popular quasi-Newton method that leverages limited memory BFGS approximation of the inverse of the Hessian. Through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box-constraints.
+  - `LBFGS`: The popular quasi-Newton method that leverages the limited-memory BFGS approximation of the inverse of the Hessian, through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) Fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box constraints.
 
    This can also handle arbitrary nonlinear constraints through an Augmented Lagrangian method with bounds constraints, described in section 17.4 of Numerical Optimization by Nocedal and Wright, thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl.
 
+  - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second-order information in the form of the diagonal of the Hessian matrix, hence avoiding the need to compute the complete Hessian. It has been shown to converge faster than other first-order methods such as Adam and SGD.
+
+      + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`
+
+      + `η` is the learning rate
+      + `βs` are the decay of momentums
+      + `ϵ` is the epsilon value
+      + `λ` is the weight decay parameter
+      + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix
+      + `ρ` is the momentum
+      + Defaults:
+
+        * `η = 0.001`
+        * `βs = (0.9, 0.999)`
+        * `ϵ = 1e-8`
+        * `λ = 0.1`
+        * `k = 10`
+        * `ρ = 0.04`
+
 ## Examples
 
 ### Unconstrained rosenbrock problem
 
 ```@example L-BFGS
+
 using Optimization, Zygote
 
 rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
@@ -27,6 +47,7 @@ sol = solve(prob, Optimization.LBFGS())
 ### With nonlinear and bounds constraints
 
 ```@example L-BFGS
+
 function con2_c(res, x, p)
     res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
 end
@@ -37,3 +58,35 @@ prob = OptimizationProblem(optf, x0, p, lcons = [1.0, -Inf],
                            ub = [1.0, 1.0])
 res = solve(prob, Optimization.LBFGS(), maxiters = 100)
 ```
+
+### Train NN with Sophia
+
+```@example Sophia
+
+using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays, Printf
+
+x = rand(10000)
+y = sin.(x)
+data = MLUtils.DataLoader((x, y), batchsize = 100)
+
+# Define the neural network
+model = Chain(Dense(1, 32, tanh), Dense(32, 1))
+ps, st = Lux.setup(Random.default_rng(), model)
+ps_ca = ComponentArray(ps)
+smodel = StatefulLuxLayer{true}(model, nothing, st)
+
+function callback(state, l)
+    state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
+    return l < 1e-1 ## Terminate if loss is small
+end
+
+function loss(ps, data)
+    ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
+    return sum(abs2, ypred .- data[2])
+end
+
+optf = OptimizationFunction(loss, AutoZygote())
+prob = OptimizationProblem(optf, ps_ca, data)
+
+res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)
+```
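Note (not part of this commit): the defaults listed above can also be spelled out explicitly when constructing the optimizer. The snippet below is a hedged sketch; the quadratic objective and `maxiters` value are assumptions made purely for illustration.

```julia
using Optimization, Zygote

# A simple smooth objective just to exercise the keyword interface.
quadratic(x, p) = sum(abs2, x .- p)
optf = OptimizationFunction(quadratic, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(3), [1.0, 2.0, 3.0])

# Spelling out the documented defaults explicitly (equivalent to Sophia()).
opt = Optimization.Sophia(η = 0.001, βs = (0.9, 0.999), ϵ = 1e-8,
                          λ = 0.1, k = 10, ρ = 0.04)

sol = solve(prob, opt, maxiters = 500)
```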

docs/src/tutorials/minibatch.md

Lines changed: 14 additions & 11 deletions
@@ -1,14 +1,15 @@
 # Data Iterators and Minibatching
 
-It is possible to solve an optimization problem with batches using a `Flux.Data.DataLoader`, which is passed to `Optimization.solve` with `ncycles`. All data for the batches need to be passed as a tuple of vectors.
+It is possible to solve an optimization problem with batches using an `MLUtils.DataLoader`, which is passed to `Optimization.solve` with `ncycle`. All data for the batches needs to be passed as a tuple of vectors.
 
 !!! note
 
     This example uses the OptimizationOptimisers.jl package. See the
    [Optimisers.jl page](@ref optimisers) for details on the installation and usage.
 
-```@example
-using Flux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity
+```@example minibatch
+
+using Lux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity, MLUtils, ComponentArrays, Random
 
 function newtons_cooling(du, u, p, t)
     temp = u[1]
@@ -21,14 +22,16 @@ function true_sol(du, u, p, t)
     newtons_cooling(du, u, true_p, t)
 end
 
-ann = Chain(Dense(1, 8, tanh), Dense(8, 1, tanh))
-pp, re = Flux.destructure(ann)
+model = Chain(Dense(1, 32, tanh), Dense(32, 1))
+ps, st = Lux.setup(Random.default_rng(), model)
+ps_ca = ComponentArray(ps)
+smodel = StatefulLuxLayer{true}(model, nothing, st)
 
 function dudt_(u, p, t)
-    re(p)(u) .* u
+    smodel(u, p) .* u
 end
 
-callback = function (state, l, pred; doplot = false) # callback function to observe training
+function callback(state, l, pred; doplot = false) # callback function to observe training
     display(l)
     # plot current prediction against data
     if doplot
@@ -53,21 +56,21 @@ function predict_adjoint(fullp, time_batch)
     Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch))
 end
 
-function loss_adjoint(fullp, batch, time_batch)
+function loss_adjoint(fullp, data)
+    batch, time_batch = data
     pred = predict_adjoint(fullp, time_batch)
     sum(abs2, batch .- pred), pred
 end
 
 k = 10
 # Pass the data for the batches as separate vectors wrapped in a tuple
-train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
+train_loader = MLUtils.DataLoader((ode_data, t), batchsize = k)
 
 numEpochs = 300
 l1 = loss_adjoint(pp, train_loader.data[1], train_loader.data[2])[1]
 
 optfun = OptimizationFunction(
-    (θ, p, batch, time_batch) -> loss_adjoint(θ, batch,
-        time_batch),
+    loss_adjoint,
     Optimization.AutoZygote())
 optprob = OptimizationProblem(optfun, pp)
 using IterTools: ncycle
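Note (not part of this commit): since the tutorial now relies on `MLUtils.DataLoader` yielding `(batch, time_batch)` tuples that `loss_adjoint` destructures, the standalone sketch below illustrates that iteration behaviour with placeholder data.

```julia
using MLUtils

# Placeholder arrays standing in for the tutorial's ode_data and t.
ode_data = rand(1, 200)
t = collect(range(0.0, 3.0, length = 200))

train_loader = MLUtils.DataLoader((ode_data, t), batchsize = 10)

# Each iteration yields a tuple of aligned minibatches, which is exactly what
# loss_adjoint(fullp, data) unpacks via `batch, time_batch = data`.
for (batch, time_batch) in train_loader
    @assert size(batch, 2) == length(time_batch)
end
```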

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
         cache.solver_args.epochs
     end
 
-    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
+    maxiters = Optimization._check_and_convert_maxiters(maxiters)
     if maxiters === nothing
         throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
     end
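Note (not part of this commit): the one-line change above makes `_check_and_convert_maxiters` operate on the already-resolved `maxiters` value, which is taken from either the `epochs` or the `maxiters` keyword, instead of re-reading `cache.solver_args.maxiters`. The sketch below shows the two call patterns this is meant to support; the toy objective and iteration count are assumptions, and the `epochs` form is inferred from the error message in the source.

```julia
using Optimization, OptimizationOptimisers, Zygote

optf = OptimizationFunction((x, p) -> sum(abs2, x), Optimization.AutoZygote())
prob = OptimizationProblem(optf, ones(2))

# Either keyword supplies the iteration budget; omitting both raises the
# ArgumentError shown in the source above.
sol_a = solve(prob, Optimisers.Adam(0.01), maxiters = 200)
sol_b = solve(prob, Optimisers.Adam(0.01), epochs = 200)
```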

test/minibatch.jl

Lines changed: 4 additions & 5 deletions
@@ -58,11 +58,10 @@ optfun = OptimizationFunction(loss_adjoint,
     Optimization.AutoZygote())
 optprob = OptimizationProblem(optfun, pp, train_loader)
 
-# res1 = Optimization.solve(optprob,
-#     Optimization.Sophia(; η = 0.5,
-#         λ = 0.0), callback = callback,
-#     maxiters = 1000)
-# @test 10res1.objective < l1
+res1 = Optimization.solve(optprob,
+    Optimization.Sophia(), callback = callback,
+    maxiters = 1000)
+@test 10res1.objective < l1
 
 optfun = OptimizationFunction(loss_adjoint,
     Optimization.AutoForwardDiff())
