Skip to content

Commit 5cf459a

Browse files
separate out fixed parameter and dataloader cases explicitly for now
1 parent d09cf00 commit 5cf459a

File tree

6 files changed

+34
-14
lines changed

6 files changed

+34
-14
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1313
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
14+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1415
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
1516
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1617
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
b = 0.5)); callback = CallbackTester(length(x0)))
4343
@test 10 * sol.objective < l1
4444

45-
f = OptimizationFunction(rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()))
45+
f = OptimizationFunction(rosenbrock, AutoReverseDiff())
4646

4747
Random.seed!(1234)
4848
prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])

lib/OptimizationOptimisers/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Vaibhav Dixit <[email protected]> and contributors"]
44
version = "0.2.1"
55

66
[deps]
7+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
78
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
89
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
910
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OptimizationOptimisers
22

33
using Reexport, Printf, ProgressLogging
44
@reexport using Optimisers, Optimization
5-
using Optimization.SciMLBase
5+
using Optimization.SciMLBase, MLUtils
66

77
SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
88
SciMLBase.requiresgradient(opt::AbstractRule) = true
@@ -57,10 +57,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
5757
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
5858
end
5959

60-
if cache.p == SciMLBase.NullParameters()
61-
data = OptimizationBase.DEFAULT_DATA
62-
else
60+
if cache.p isa MLUtils.DataLoader
6361
data = cache.p
62+
dataiterate = true
63+
else
64+
data = [cache.p]
65+
dataiterate = false
6466
end
6567
opt = cache.opt
6668
θ = copy(cache.u0)
@@ -77,11 +79,16 @@ function SciMLBase.__solve(cache::OptimizationCache{
7779
Optimization.@withprogress cache.progress name="Training" begin
7880
for _ in 1:maxiters
7981
for (i, d) in enumerate(data)
80-
if cache.f.fg !== nothing
82+
if cache.f.fg !== nothing && dataiterate
8183
x = cache.f.fg(G, θ, d)
82-
else
84+
elseif dataiterate
8385
cache.f.grad(G, θ, d)
8486
x = cache.f(θ, d)
87+
elseif cache.f.fg !== nothing
88+
x = cache.f.fg(G, θ)
89+
else
90+
cache.f.grad(G, θ)
91+
x = cache.f(θ)
8592
end
8693
opt_state = Optimization.OptimizationState(iter = i,
8794
u = θ,

src/sophia.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
6464

6565
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
6666

67-
if cache.p == SciMLBase.NullParameters()
68-
data = OptimizationBase.DEFAULT_DATA
69-
else
67+
if cache.p isa MLUtils.DataLoader
7068
data = cache.p
69+
dataiterate = true
70+
else
71+
data = [cache.p]
72+
dataiterate = false
7173
end
7274

7375
f = cache.f
@@ -77,14 +79,23 @@ function SciMLBase.__solve(cache::OptimizationCache{
7779
hₜ = zero(θ)
7880
for _ in 1:maxiters
7981
for (i, d) in enumerate(data)
80-
f.grad(gₜ, θ, d)
81-
x = cache.f(θ, d)
82+
if cache.f.fg !== nothing && dataiterate
83+
x = cache.f.fg(G, θ, d)
84+
elseif dataiterate
85+
cache.f.grad(G, θ, d)
86+
x = cache.f(θ, d)
87+
elseif cache.f.fg !== nothing
88+
x = cache.f.fg(G, θ)
89+
else
90+
cache.f.grad(G, θ)
91+
x = cache.f(θ)
92+
end
8293
opt_state = Optimization.OptimizationState(; iter = i,
8394
u = θ,
8495
objective = first(x),
8596
grad = gₜ,
8697
original = nothing)
87-
cb_call = cache.callback(θ, x...)
98+
cb_call = cache.callback(opt_state, x...)
8899
if !(cb_call isa Bool)
89100
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
90101
elseif cb_call

test/minibatch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function dudt_(u, p, t)
1919
ann(u, p, st)[1] .* u
2020
end
2121

22-
function callback(state, l) #callback function to observe training
22+
function callback(state, l, pred) #callback function to observe training
2323
display(l)
2424
return false
2525
end

0 commit comments

Comments
 (0)