Skip to content

Commit e601dc9

Browse files
committed
Compute JVP in line searches
1 parent d4bf817 commit e601dc9

File tree

12 files changed

+65
-47
lines changed

12 files changed

+65
-47
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ jobs:
1111
strategy:
1212
matrix:
1313
version:
14-
- "min"
15-
- "lts"
14+
# - "min"
15+
# - "lts"
1616
- "1"
1717
os:
1818
- ubuntu-latest

.github/workflows/Docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
with:
2020
version: '1'
2121
- name: Install dependencies
22-
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
22+
run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
2323
- name: Build and deploy
2424
env:
2525
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ LineSearches = "7.4.0"
3434
LinearAlgebra = "<0.0.1, 1.6"
3535
MathOptInterface = "1.17"
3636
Measurements = "2.14.1"
37-
NLSolversBase = "7.9.0"
37+
NLSolversBase = "8"
3838
NaNMath = "0.3.2, 1"
3939
OptimTestProblems = "2.0.3"
4040
PositiveFactorizations = "0.2.2"
@@ -65,3 +65,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6565

6666
[targets]
6767
test = ["Test", "Aqua", "Distributions", "ExplicitImports", "ForwardDiff", "JET", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "ReverseDiff"]
68+
69+
[sources]
70+
LineSearches = { url = "https://github.com/devmotion/LineSearches.jl.git", rev = "dmw/jvp" }
71+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
Documenter = "1"
1313
Literate = "2"
1414

15-
[sources.Optim]
16-
path = ".."
15+
[sources]
16+
Optim = { path = ".." }
17+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/src/examples/ipnewton_basics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# constraint is unbounded from below or above respectively.
2323

2424
using Optim, NLSolversBase #hide
25+
import ADTypes #hide
2526
import NLSolversBase: clear! #hide
2627

2728
# # Constrained optimization with `IPNewton`

ext/OptimMOIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module OptimMOIExt
22

33
using Optim
4-
using Optim.LinearAlgebra: rmul!
4+
using Optim.LinearAlgebra: rmul!
55
import MathOptInterface as MOI
66

77
function __init__()

src/Manifolds.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020
# Out-of-place counterpart of `retract!`: works on a copy so the caller's `x` is untouched.
function retract(M::Manifold, x)
    x_copy = copy(x)
    return retract!(M, x_copy)
end
2121

2222
# Fake objective function implementing a retraction
23-
mutable struct ManifoldObjective{T<:NLSolversBase.AbstractObjective} <:
23+
mutable struct ManifoldObjective{M<:Manifold,T<:NLSolversBase.AbstractObjective} <:
2424
NLSolversBase.AbstractObjective
25-
manifold::Manifold
25+
manifold::M
2626
inner_obj::T
2727
end
2828
# TODO: is it safe here to call retract! and change x?
@@ -52,6 +52,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
5252
return value(obj.inner_obj)
5353
end
5454

55+
# In general, we have to compute the gradient/Jacobian separately as it has to be projected
56+
# Directional derivative of the manifold objective at `x` along `v`.
# The gradient of the inner objective is evaluated at the retracted point, projected
# with `project_tangent!`, and then contracted with `v`.
function NLSolversBase.jvp!(obj::ManifoldObjective, x, v)
    M = obj.manifold
    inner = obj.inner_obj
    x_retracted = retract(M, x)
    gradient!(inner, x_retracted)
    # The projection overwrites the gradient cached inside the inner objective
    project_tangent!(M, gradient(inner), x_retracted)
    projected = gradient(inner)
    return dot(projected, v)
end
62+
# Objective value and directional derivative along `v`, both taken at the retraction of `x`.
# Returns the tuple `(value, dot(projected_gradient, v))`.
function NLSolversBase.value_jvp!(obj::ManifoldObjective, x, v)
    M = obj.manifold
    inner = obj.inner_obj
    x_retracted = retract(M, x)
    value_gradient!(inner, x_retracted)
    # The projection overwrites the gradient cached inside the inner objective
    project_tangent!(M, gradient(inner), x_retracted)
    fx = value(inner)
    slope = dot(gradient(inner), v)
    return fx, slope
end
68+
5569
"""
    Flat()

Flat Euclidean space {R,C}^N, with projections equal to the identity.
"""
struct Flat <: Manifold
end
5771
# all the functions below are no-ops, and therefore the generated code
@@ -62,6 +76,10 @@ retract!(M::Flat, x) = x
6276
# On `Flat` the tangent projection is the identity: `g` is handed back as-is.
function project_tangent(M::Flat, g, x)
    return g
end
6377
# In-place variant for `Flat`: nothing to project, so `g` is returned unmodified.
function project_tangent!(M::Flat, g, x)
    return g
end
6478

79+
# Optimizations for `Flat` manifold
80+
# `Flat` shortcut: forward directly to the inner objective's `jvp!`.
function NLSolversBase.jvp!(obj::ManifoldObjective{Flat}, x, v)
    return jvp!(obj.inner_obj, x, v)
end
81+
# `Flat` shortcut: forward directly to the inner objective's `value_jvp!`.
function NLSolversBase.value_jvp!(obj::ManifoldObjective{Flat}, x, v)
    return value_jvp!(obj.inner_obj, x, v)
end
82+
6583
"""
    Sphere()

Spherical manifold {|x| = 1}.
"""
struct Sphere <: Manifold
end
6785
# Retract onto the sphere by dividing `x` in place by its norm; returns the mutated `x`.
function retract!(S::Sphere, x)
    x_norm = norm(x)
    x ./= x_norm
    return x
end

src/Optim.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ documentation online at http://julianlsolvers.github.io/Optim.jl/stable/ .
1616
"""
1717
module Optim
1818

19+
import ADTypes
20+
1921
using PositiveFactorizations: Positive # for globalization strategy in Newton
2022

2123
using LineSearches: LineSearches # for globalization strategy in Quasi-Newton algs
@@ -41,6 +43,15 @@ using NLSolversBase:
4143
TwiceDifferentiableConstraints,
4244
nconstraints,
4345
nconstraints_x,
46+
value,
47+
value!,
48+
value!!,
49+
gradient,
50+
gradient!,
51+
value_gradient!,
52+
value_gradient!!,
53+
jvp!,
54+
value_jvp!,
4455
hessian,
4556
hessian!,
4657
hessian!!,

src/multivariate/optimize/interface.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,6 @@ promote_objtype(
6666
inplace::Bool,
6767
f::InplaceObjective,
6868
) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
69-
promote_objtype(
70-
method::SecondOrderOptimizer,
71-
x,
72-
autodiff::ADTypes.AbstractADType,
73-
inplace::Bool,
74-
f::NLSolversBase.InPlaceObjectiveFGHv,
75-
) = TwiceDifferentiableHV(f, x)
76-
promote_objtype(
77-
method::SecondOrderOptimizer,
78-
x,
79-
autodiff::ADTypes.AbstractADType,
80-
inplace::Bool,
81-
f::NLSolversBase.InPlaceObjectiveFG_Hv,
82-
) = TwiceDifferentiableHV(f, x)
8369
promote_objtype(method::SecondOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f, g) =
8470
TwiceDifferentiable(
8571
f,

src/multivariate/solvers/constrained/fminbox.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import NLSolversBase:
2-
value, value!, value!!, gradient, gradient!, value_gradient!, value_gradient!!
31
# FIXME: fix the middle-of-the-box case that used to be handled here
42
mutable struct BarrierWrapper{TO,TB,Tm,TF,TDF} <: AbstractObjective
53
obj::TO
@@ -64,24 +62,24 @@ function _barrier_term_gradient(x::T, l, u) where {T}
6462
end
6563
return g
6664
end
67-
function value_gradient!(bb::BoxBarrier, g, x)
65+
# Write the element-wise barrier gradient at `x` into `g`, then return the barrier value.
function NLSolversBase.value_gradient!(bb::BoxBarrier, g, x)
    broadcast!(_barrier_term_gradient, g, x, bb.lower, bb.upper)
    return value(bb, x)
end
71-
function gradient(bb::BoxBarrier, g, x)
69+
# Non-mutating barrier gradient: the result is written into a copy of `g`, so the
# caller's `g` is left unchanged, and that copy is returned.
function NLSolversBase.gradient(bb::BoxBarrier, g, x)
    out = copy(g)
    return broadcast!(_barrier_term_gradient, out, x, bb.lower, bb.upper)
end
7573
# Wrappers
76-
function value!!(bw::BarrierWrapper, x)
74+
# Force re-evaluation of the barrier-augmented objective at `x`.
# `Fb` and `Ftotal` are always refreshed with the barrier contribution; the wrapped
# objective is only evaluated (and added to `Ftotal`) when `x` lies inside the box.
# Returns `nothing` when `x` is outside the box.
function NLSolversBase.value!!(bw::BarrierWrapper, x)
    bw.Fb = value(bw.b, x)
    bw.Ftotal = bw.mu * bw.Fb
    in_box(bw, x) || return nothing
    value!!(bw.obj, x)
    return bw.Ftotal += value(bw.obj)
end
84-
function value_gradient!!(bw::BarrierWrapper, x)
82+
function NLSolversBase.value_gradient!!(bw::BarrierWrapper, x)
8583
bw.Fb = value(bw.b, x)
8684
bw.Ftotal = bw.mu * bw.Fb
8785
bw.DFb .= _barrier_term_gradient.(x, bw.b.lower, bw.b.upper)
@@ -93,7 +91,7 @@ function value_gradient!!(bw::BarrierWrapper, x)
9391
end
9492

9593
end
96-
function value_gradient!(bb::BarrierWrapper, x)
94+
function NLSolversBase.value_gradient!(bb::BarrierWrapper, x)
9795
bb.DFb .= _barrier_term_gradient.(x, bb.b.lower, bb.b.upper)
9896
bb.Fb = value(bb.b, x)
9997
bb.DFtotal .= bb.mu .* bb.DFb
@@ -105,9 +103,9 @@ function value_gradient!(bb::BarrierWrapper, x)
105103
bb.Ftotal += value(bb.obj)
106104
end
107105
end
108-
value(bb::BoxBarrier, x) =
106+
# Barrier value at `x`: the sum of `_barrier_term_value` over matching entries of
# `x`, `bb.lower` and `bb.upper`.
function NLSolversBase.value(bb::BoxBarrier, x)
    triples = zip(x, bb.lower, bb.upper)
    return mapreduce(((xi, lo, hi),) -> _barrier_term_value(xi, lo, hi), +, triples)
end
110-
function value!(obj::BarrierWrapper, x)
108+
function NLSolversBase.value!(obj::BarrierWrapper, x)
111109
obj.Fb = value(obj.b, x)
112110
obj.Ftotal = obj.mu * obj.Fb
113111
if in_box(obj, x)
@@ -116,20 +114,20 @@ function value!(obj::BarrierWrapper, x)
116114
end
117115
obj.Ftotal
118116
end
119-
value(obj::BarrierWrapper) = obj.Ftotal
120-
function value(obj::BarrierWrapper, x)
117+
# Cached total (objective plus scaled barrier) value from the last evaluation.
function NLSolversBase.value(obj::BarrierWrapper)
    return obj.Ftotal
end
118+
# Evaluate the barrier-augmented objective at `x` without touching the cached fields.
# The wrapped objective only contributes when `x` is inside the box.
function NLSolversBase.value(obj::BarrierWrapper, x)
    total = obj.mu * value(obj.b, x)
    in_box(obj, x) && (total += value(obj.obj, x))
    return total
end
127-
function gradient!(obj::BarrierWrapper, x)
125+
# Recompute and cache the gradient of the barrier-augmented objective at `x`.
#
# Refreshes the wrapped objective's gradient with `gradient!(obj.obj, x)`, rewrites the
# barrier-term gradient `obj.DFb` in place, and stores
# `gradient(obj.obj) .+ obj.mu .* obj.DFb` in `obj.DFtotal`, which is returned.
function NLSolversBase.gradient!(obj::BarrierWrapper, x)
    gradient!(obj.obj, x)
    # Fill the barrier gradient in place (same broadcast as in the `value_gradient!`
    # methods) instead of allocating a temporary copy first.
    obj.DFb .= _barrier_term_gradient.(x, obj.b.lower, obj.b.upper)
    # The total gradient is the objective gradient plus `mu` times the barrier
    # *gradient* `DFb`; the scalar barrier value `Fb` does not belong here.
    obj.DFtotal .= gradient(obj.obj) .+ obj.mu .* obj.DFb
    return obj.DFtotal
end
132-
gradient(obj::BarrierWrapper) = obj.DFtotal
130+
# Cached total (objective plus scaled barrier) gradient from the last evaluation.
function NLSolversBase.gradient(obj::BarrierWrapper)
    return obj.DFtotal
end
133131

134132
# this mutates mu but not the gradients
135133
# Super unsafe in that it depends on x_df being correct!
@@ -299,7 +297,7 @@ function optimize(
299297
initial_x::AbstractArray,
300298
F::Fminbox = Fminbox(),
301299
options::Options = Options();
302-
inplace = true,
300+
inplace::Bool = true,
303301
)
304302

305303
g! = inplace ? g : (G, x) -> copyto!(G, g(x))
@@ -536,12 +534,13 @@ function optimize(
536534
end
537535
results = optimize(dfbox, x, _optimizer, options, state)
538536
stopped_by_callback = results.stopped_by.callback
539-
dfbox.obj.f_calls[1] = 0
537+
# TODO: Define an API (e.g. `reset_calls!`?) in NLSolversBase
538+
dfbox.obj.f_calls = 0
540539
if hasfield(typeof(dfbox.obj), :df_calls)
541-
dfbox.obj.df_calls[1] = 0
540+
dfbox.obj.df_calls = 0
542541
end
543542
if hasfield(typeof(dfbox.obj), :h_calls)
544-
dfbox.obj.h_calls[1] = 0
543+
dfbox.obj.h_calls = 0
545544
end
546545
copyto!(x, minimizer(results))
547546
boxdist = Base.minimum(((xi, li, ui),) -> min(xi - li, ui - xi), zip(x, l, u)) # Base.minimum !== minimum
@@ -613,12 +612,13 @@ function optimize(
613612
resultsnew = optimize(dfbox, x, _optimizer, options, state)
614613
stopped_by_callback = resultsnew.stopped_by.callback
615614
append!(results, resultsnew)
616-
dfbox.obj.f_calls[1] = 0
615+
# TODO: Define an API (e.g. `reset_calls!`?) in NLSolversBase
616+
dfbox.obj.f_calls = 0
617617
if hasfield(typeof(dfbox.obj), :df_calls)
618-
dfbox.obj.df_calls[1] = 0
618+
dfbox.obj.df_calls = 0
619619
end
620620
if hasfield(typeof(dfbox.obj), :h_calls)
621-
dfbox.obj.h_calls[1] = 0
621+
dfbox.obj.h_calls = 0
622622
end
623623
copyto!(x, minimizer(results))
624624
boxdist = Base.minimum(((xi, li, ui),) -> min(xi - li, ui - xi), zip(x, l, u)) # Base.minimum !== minimum

0 commit comments

Comments
 (0)