Skip to content

Commit 0dab5f5

Browse files
Update optim fg, optimisers extensions and tests
1 parent: 4127df2 · commit: 0dab5f5

File tree

7 files changed

+39
-54
lines changed

7 files changed

+39
-54
lines changed

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ function SciMLBase.requireshessian(opt::Union{
2626
true
2727
end
2828
SciMLBase.requiresgradient(opt::Optim.Fminbox) = true
29+
SciMLBase.allowsfg(opt::Union{Optim.AbstractOptimizer, Optim.ConstrainedOptimizer, Optim.Fminbox, Optim.SAMIN}) = true
2930

3031
function __map_optimizer_args(cache::OptimizationCache,
3132
opt::Union{Optim.AbstractOptimizer, Optim.Fminbox,
@@ -142,11 +143,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
142143
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
143144
opt_state = Optimization.OptimizationState(iter = trace.iteration,
144145
u = θ,
145-
objective = x[1],
146+
objective = trace.value,
146147
grad = get(metadata, "g(x)", nothing),
147148
hess = get(metadata, "h(x)", nothing),
148149
original = trace)
149-
cb_call = cache.callback(opt_state, x...)
150+
cb_call = cache.callback(opt_state, trace.value)
150151
if !(cb_call isa Bool)
151152
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
152153
end
@@ -261,11 +262,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
261262
metadata["x"]
262263
opt_state = Optimization.OptimizationState(iter = trace.iteration,
263264
u = θ,
264-
objective = x[1],
265+
objective = trace.value,
265266
grad = get(metadata, "g(x)", nothing),
266267
hess = get(metadata, "h(x)", nothing),
267268
original = trace)
268-
cb_call = cache.callback(opt_state, x...)
269+
cb_call = cache.callback(opt_state, trace.value)
269270
if !(cb_call isa Bool)
270271
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
271272
end
@@ -277,15 +278,21 @@ function SciMLBase.__solve(cache::OptimizationCache{
277278
__x = first(x)
278279
return cache.sense === Optimization.MaxSense ? -__x : __x
279280
end
280-
fg! = function (G, θ)
281-
if G !== nothing
282-
cache.f.grad(G, θ)
283-
if cache.sense === Optimization.MaxSense
284-
G .*= -one(eltype(G))
281+
282+
if cache.f.fg === nothing
283+
fg! = function (G, θ)
284+
if G !== nothing
285+
cache.f.grad(G, θ)
286+
if cache.sense === Optimization.MaxSense
287+
G .*= -one(eltype(G))
288+
end
285289
end
290+
return _loss(θ)
286291
end
287-
return _loss(θ)
292+
else
293+
fg! = cache.f.fg
288294
end
295+
289296

290297
gg = function (G, θ)
291298
cache.f.grad(G, θ)
@@ -344,9 +351,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
344351
u = metadata["x"],
345352
grad = get(metadata, "g(x)", nothing),
346353
hess = get(metadata, "h(x)", nothing),
347-
objective = x[1],
354+
objective = trace.value,
348355
original = trace)
349-
cb_call = cache.callback(opt_state, x...)
356+
cb_call = cache.callback(opt_state, trace.value)
350357
if !(cb_call isa Bool)
351358
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
352359
end
@@ -358,15 +365,21 @@ function SciMLBase.__solve(cache::OptimizationCache{
358365
__x = first(x)
359366
return cache.sense === Optimization.MaxSense ? -__x : __x
360367
end
361-
fg! = function (G, θ)
362-
if G !== nothing
363-
cache.f.grad(G, θ)
364-
if cache.sense === Optimization.MaxSense
365-
G .*= -one(eltype(G))
368+
369+
if cache.f.fg === nothing
370+
fg! = function (G, θ)
371+
if G !== nothing
372+
cache.f.grad(G, θ)
373+
if cache.sense === Optimization.MaxSense
374+
G .*= -one(eltype(G))
375+
end
366376
end
377+
return _loss(θ)
367378
end
368-
return _loss(θ)
379+
else
380+
fg! = cache.f.fg
369381
end
382+
370383
gg = function (G, θ)
371384
cache.f.grad(G, θ)
372385
if cache.sense === Optimization.MaxSense
@@ -434,7 +447,7 @@ PrecompileTools.@compile_workload begin
434447
function obj_f(x, p)
435448
A = p[1]
436449
b = p[2]
437-
return sum((A * x - b) .^ 2)
450+
return sum((A * x .- b) .^ 2)
438451
end
439452

440453
function solve_nonnegative_least_squares(A, b, solver)

lib/OptimizationOptimisers/Project.toml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1111
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1212

13-
[weakdeps]
14-
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
15-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
16-
17-
[extensions]
18-
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
19-
OptimizationOptimisersMLUtilsExt = "MLUtils"
20-
2113
[compat]
22-
MLDataDevices = "1.1"
23-
MLUtils = "0.4.4"
2414
Optimisers = "0.2, 0.3"
2515
Optimization = "4"
2616
ProgressLogging = "0.1"

lib/OptimizationOptimisers/ext/OptimizationOptimisersMLDataDevicesExt.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

lib/OptimizationOptimisers/ext/OptimizationOptimisersMLUtilsExt.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OptimizationOptimisers
22

33
using Reexport, Printf, ProgressLogging
44
@reexport using Optimisers, Optimization
5-
using Optimization.SciMLBase
5+
using Optimization.SciMLBase, Optimization.OptimizationBase
66

77
SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
88
SciMLBase.requiresgradient(opt::AbstractRule) = true
@@ -16,8 +16,6 @@ function SciMLBase.__init(
1616
kwargs...)
1717
end
1818

19-
isa_dataiterator(data) = false
20-
2119
function SciMLBase.__solve(cache::OptimizationCache{
2220
F,
2321
RC,
@@ -59,7 +57,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
5957
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
6058
end
6159

62-
if isa_dataiterator(cache.p)
60+
if OptimizationBase.isa_dataiterator(cache.p)
6361
data = cache.p
6462
dataiterate = true
6563
else

test/diffeqfluxtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ end
3131
function loss_adjoint(p)
3232
prediction = predict_adjoint(p)
3333
loss = sum(abs2, x - 1 for x in prediction)
34-
return loss, prediction
34+
return loss
3535
end
3636

3737
iter = 0
38-
callback = function (state, l, pred)
38+
callback = function (state, l)
3939
display(l)
4040

4141
# using `remake` to re-create our `prob` with current parameters `p`
@@ -81,11 +81,11 @@ end
8181
function loss_neuralode(p)
8282
pred = predict_neuralode(p)
8383
loss = sum(abs2, ode_data .- pred)
84-
return loss, pred
84+
return loss
8585
end
8686

8787
iter = 0
88-
callback = function (st, l, pred...)
88+
callback = function (st, l)
8989
global iter
9090
iter += 1
9191

test/minibatch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
function loss_adjoint(fullp, p)
4646
(batch, time_batch) = p
4747
pred = predict_adjoint(fullp, time_batch)
48-
sum(abs2, batch .- pred), pred
48+
sum(abs2, batch .- pred)
4949
end
5050

5151
k = 10

0 commit comments

Comments (0)