Commit 951d661

updates for CI

1 parent 4a9737c commit 951d661

6 files changed: +44 -43 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ LBFGSB = "0.4.1"
 LinearAlgebra = "1.10"
 Logging = "1.10"
 LoggingExtras = "0.4, 1"
-OptimizationBase = "2.0.1"
+OptimizationBase = "2.0.2"
 Printf = "1.10"
 ProgressLogging = "0.1"
 Reexport = "1.2"

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 3 additions & 0 deletions

@@ -13,6 +13,9 @@ internal state.
 abstract type AbstractManoptOptimizer end

 SciMLBase.supports_opt_cache_interface(opt::AbstractManoptOptimizer) = true
+SciMLBase.requiresgradient(opt::Union{GradientDescentOptimizer, ConjugateGradientDescentOptimizer, QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer}) = true
+SciMLBase.requireshessian(opt::Union{AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer}) = true
+

 function __map_optimizer_args!(cache::OptimizationCache,
         opt::AbstractManoptOptimizer;
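
For context: these SciMLBase traits tell the common Optimization.jl machinery which derivatives each Manopt solver consumes, so gradient and Hessian oracles can be prepared (or their absence flagged) before the solve starts. A minimal sketch of how such traits might be consulted — hypothetical code, not part of this commit:

# Hypothetical trait consumer (illustrative only; not from the repository).
function check_derivative_requirements(f::OptimizationFunction, opt)
    if SciMLBase.requiresgradient(opt) && f.grad === nothing && f.adtype isa SciMLBase.NoAD
        error("$(typeof(opt)) needs gradients; supply `grad` or an AD backend.")
    end
    if SciMLBase.requireshessian(opt) && f.hess === nothing && f.adtype isa SciMLBase.NoAD
        error("$(typeof(opt)) needs Hessians; supply `hess` or an AD backend.")
    end
    return nothing
end

With the defaults declared above, the gradient-based Manopt solvers (gradient descent, conjugate gradient, quasi-Newton, convex bundle, Frank-Wolfe) advertise a gradient requirement, while the cubic-regularization and trust-region solvers require Hessian information.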

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 using OptimizationOptimJL,
-      OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote,
+      OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote, ReverseDiff,
       Random, ModelingToolkit, Optimization.OptimizationBase.DifferentiationInterface
 using Test

src/sophia.jl

Lines changed: 35 additions & 33 deletions

@@ -10,6 +10,9 @@ struct Sophia
 end

 SciMLBase.supports_opt_cache_interface(opt::Sophia) = true
+SciMLBase.requiresgradient(opt::Sophia) = true
+SciMLBase.allowsfg(opt::Sophia) = true
+SciMLBase.requireshessian(opt::Sophia) = true

 function Sophia(; η = 1e-3, βs = (0.9, 0.999), ϵ = 1e-8, λ = 1e-1, k = 10,
         ρ = 0.04)
@@ -18,11 +21,10 @@ end

 clip(z, ρ) = max(min(z, ρ), -ρ)

-function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia,
-        data = Optimization.DEFAULT_DATA;
+function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia;
         maxiters::Number = 1000, callback = (args...) -> (false),
         progress = false, save_best = true, kwargs...)
-    return OptimizationCache(prob, opt, data; maxiters, callback, progress,
+    return OptimizationCache(prob, opt; maxiters, callback, progress,
         save_best, kwargs...)
 end

@@ -60,46 +62,46 @@ function SciMLBase.__solve(cache::OptimizationCache{
     λ = uType(cache.opt.λ)
     ρ = uType(cache.opt.ρ)

-    if cache.data != Optimization.DEFAULT_DATA
-        maxiters = length(cache.data)
-        data = cache.data
+    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
+
+    if cache.p == SciMLBase.NullParameters()
+        data = OptimizationBase.DEFAULT_DATA
     else
-        maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
-        data = Optimization.take(cache.data, maxiters)
+        data = cache.p
     end

-    maxiters = Optimization._check_and_convert_maxiters(maxiters)
-
     f = cache.f
     θ = copy(cache.u0)
     gₜ = zero(θ)
     mₜ = zero(θ)
     hₜ = zero(θ)
-    for (i, d) in enumerate(data)
-        f.grad(gₜ, θ, d...)
-        x = cache.f(θ, cache.p, d...)
-        opt_state = Optimization.OptimizationState(; iter = i,
-            u = θ,
-            objective = first(x),
-            grad = gₜ,
-            original = nothing)
-        cb_call = cache.callback(θ, x...)
-        if !(cb_call isa Bool)
-            error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
-        elseif cb_call
-            break
-        end
-        mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ
+    for _ in 1:maxiters
+        for (i, d) in enumerate(data)
+            f.grad(gₜ, θ, d)
+            x = cache.f(θ, cache.p, d...)
+            opt_state = Optimization.OptimizationState(; iter = i,
+                u = θ,
+                objective = first(x),
+                grad = gₜ,
+                original = nothing)
+            cb_call = cache.callback(θ, x...)
+            if !(cb_call isa Bool)
+                error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
+            elseif cb_call
+                break
+            end
+            mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ

-        if i % cache.opt.k == 1
-            hₜ₋₁ = copy(hₜ)
-            u = randn(uType, length(θ))
-            f.hv(hₜ, θ, u, d...)
-            hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ)
+            if i % cache.opt.k == 1
+                hₜ₋₁ = copy(hₜ)
+                u = randn(uType, length(θ))
+                f.hv(hₜ, θ, u, d)
+                hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ)
+            end
+            θ = θ .- η * λ .* θ
+            θ = θ .-
+                η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ))
         end
-        θ = θ .- η * λ .* θ
-        θ = θ .-
-            η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ))
     end

     return SciMLBase.build_solution(cache, cache.opt,
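
To make the rewritten loop easier to follow: the old single pass over `data` is now wrapped in `for _ in 1:maxiters`, so `maxiters` counts epochs over the data iterator (now taken from `cache.p`) rather than total steps. Within each step, Sophia keeps an exponential moving average mₜ of gradients and, every k steps, refreshes a Hutchinson-style diagonal curvature estimate hₜ from a Hessian-vector product against a random direction; the parameter update applies decoupled weight decay and a preconditioned step whose entries are clipped to ±ρ. A stripped-down sketch of that update on a toy quadratic — illustrative helper names, not the package's API:

# Standalone sketch of the Sophia update rule used in the loop above.
clip(z, ρ) = max(min(z, ρ), -ρ)

function sophia_step!(θ, m, h, t, grad, hvp;
        η = 1e-3, βs = (0.9, 0.999), ϵ = 1e-8, λ = 1e-1, k = 10, ρ = 0.04)
    g = grad(θ)
    m .= βs[1] .* m .+ (1 - βs[1]) .* g                     # EMA of gradients
    if t % k == 1                                           # refresh curvature every k steps
        u = randn(length(θ))
        h .= βs[2] .* h .+ (1 - βs[2]) .* (u .* hvp(θ, u))  # Hutchinson diagonal estimate
    end
    θ .-= η * λ .* θ                                        # decoupled weight decay
    θ .-= η .* clip.(m ./ max.(h, ϵ), ρ)                    # clipped, preconditioned step
    return θ
end

# For f(θ) = ‖θ‖² the exact oracles are grad(θ) = 2θ and hvp(θ, u) = 2u;
# repeated steps shrink θ toward the minimizer at zero.
θ = randn(3); m = zero(θ); h = zero(θ)
for t in 1:200
    sophia_step!(θ, m, h, t, θ -> 2θ, (θ, u) -> 2u; η = 0.5)
end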

test/ADtests.jl

Lines changed: 0 additions & 6 deletions

@@ -30,12 +30,6 @@ end
     sol = solve(prob, Optim.Newton())
     @test 10 * sol.objective < l1
     @test sol.retcode == ReturnCode.Success
-
-    sol = Optimization.solve(prob,
-        Optimization.Sophia(; η = 0.5,
-            λ = 0.0),
-        maxiters = 1000)
-    @test 10 * sol.objective < l1
 end

 @testset "No constraint" begin

test/minibatch.jl

Lines changed: 4 additions & 2 deletions

@@ -58,8 +58,10 @@ optfun = OptimizationFunction(loss_adjoint,
     Optimization.AutoZygote())
 optprob = OptimizationProblem(optfun, pp, train_loader)

-res1 = Optimization.solve(optprob, Optimisers.Adam(0.05),
-    callback = callback, maxiters = numEpochs)
+sol = Optimization.solve(optprob,
+    Optimization.Sophia(; η = 0.5,
+        λ = 0.0),
+    maxiters = 1000)
 @test 10res1.objective < l1

 optfun = OptimizationFunction(loss_adjoint,
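
The test added here exercises the new data-handling path: the minibatch loader is passed as the problem's parameter object (`OptimizationProblem(optfun, pp, train_loader)`), which the rewritten sophia.jl loop picks up via `data = cache.p`. A rough, self-contained sketch of the same pattern — the toy loss, data, and use of MLUtils.DataLoader are assumptions for illustration, not code from the repository:

# Hypothetical minimal example of Sophia on minibatched data.
using Optimization, MLUtils, Zygote

X = rand(2, 64)                          # toy features
Y = 3 .* X[1, :] .- X[2, :]              # toy targets
loader = MLUtils.DataLoader((X, Y), batchsize = 8)

# The loss receives the current batch as its second argument.
loss(w, d) = sum(abs2, vec(w' * d[1]) .- d[2])

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), loader)  # loader rides in the `p` slot
sol = solve(prob, Optimization.Sophia(; η = 0.5, λ = 0.0), maxiters = 100)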
