Skip to content

Commit 536eaa9

Browse files
committed
Compute JVP in line searches
1 parent 540c97b commit 536eaa9

40 files changed

+310
-238
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ jobs:
1111
strategy:
1212
matrix:
1313
version:
14-
- "min"
15-
- "lts"
14+
# - "min"
15+
# - "lts"
1616
- "1"
1717
os:
1818
- ubuntu-latest

.github/workflows/Docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
with:
2020
version: '1'
2121
- name: Install dependencies
22-
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
22+
run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
2323
- name: Build and deploy
2424
env:
2525
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1818
[weakdeps]
1919
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
2020

21+
[sources]
22+
LineSearches = { url = "https://github.com/devmotion/LineSearches.jl.git", rev = "dmw/jvp" }
23+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }
24+
2125
[extensions]
2226
OptimMOIExt = "MathOptInterface"
2327

@@ -30,11 +34,11 @@ ExplicitImports = "1.13.2"
3034
FillArrays = "0.6.2, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
3135
ForwardDiff = "0.10, 1"
3236
JET = "0.9, 0.10"
33-
LineSearches = "7.5.1"
37+
LineSearches = "7.6"
3438
LinearAlgebra = "<0.0.1, 1.6"
3539
MathOptInterface = "1.17"
3640
Measurements = "2.14.1"
37-
NLSolversBase = "7.9.0"
41+
NLSolversBase = "8"
3842
NaNMath = "0.3.2, 1"
3943
OptimTestProblems = "2.0.3"
4044
PositiveFactorizations = "0.2.2"

docs/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
Documenter = "1"
1313
Literate = "2"
1414

15-
[sources.Optim]
16-
path = ".."
15+
[sources]
16+
Optim = { path = ".." }
17+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/src/examples/ipnewton_basics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# constraint is unbounded from below or above respectively.
2323

2424
using Optim, NLSolversBase #hide
25+
import ADTypes #hide
2526
import NLSolversBase: clear! #hide
2627

2728
# # Constrained optimization with `IPNewton`

ext/OptimMOIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module OptimMOIExt
22

33
using Optim
4-
using Optim.LinearAlgebra: rmul!
4+
using Optim.LinearAlgebra: rmul!
55
import MathOptInterface as MOI
66

77
function __init__()

src/Manifolds.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020
retract(M::Manifold, x) = retract!(M, copy(x))
2121

2222
# Fake objective function implementing a retraction
23-
mutable struct ManifoldObjective{T<:NLSolversBase.AbstractObjective} <:
24-
NLSolversBase.AbstractObjective
25-
manifold::Manifold
23+
struct ManifoldObjective{M<:Manifold,T<:AbstractObjective} <: AbstractObjective
24+
manifold::M
2625
inner_obj::T
2726
end
2827
# TODO: is it safe here to call retract! and change x?
@@ -43,6 +42,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
4342
return f_x, g_x
4443
end
4544

45+
# In general, we have to compute the gradient/Jacobian separately as it has to be projected
46+
function NLSolversBase.jvp!(obj::ManifoldObjective, x, v)
47+
xin = retract(obj.manifold, x)
48+
g_x = gradient!(obj.inner_obj, xin)
49+
project_tangent!(obj.manifold, g_x, xin)
50+
return dot(g_x, v)
51+
end
52+
function NLSolversBase.value_jvp!(obj::ManifoldObjective, x, v)
53+
xin = retract(obj.manifold, x)
54+
f_x, g_x = value_gradient!(obj.inner_obj, xin)
55+
project_tangent!(obj.manifold, g_x, xin)
56+
return f_x, dot(g_x, v)
57+
end
58+
4659
"""Flat Euclidean space {R,C}^N, with projections equal to the identity."""
4760
struct Flat <: Manifold end
4861
# all the functions below are no-ops, and therefore the generated code
@@ -53,6 +66,10 @@ retract!(M::Flat, x) = x
5366
project_tangent(M::Flat, g, x) = g
5467
project_tangent!(M::Flat, g, x) = g
5568

69+
# Optimizations for `Flat` manifold
70+
NLSolversBase.jvp!(obj::ManifoldObjective{Flat}, x, v) = NLSolversBase.jvp!(obj.inner_obj, x, v)
71+
NLSolversBase.value_jvp!(obj::ManifoldObjective{Flat}, x, v) = NLSolversBase.value_jvp!(obj.inner_obj, x, v)
72+
5673
"""Spherical manifold {|x| = 1}."""
5774
struct Sphere <: Manifold end
5875
retract!(S::Sphere, x) = (x ./= norm(x))

src/Optim.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ documentation online at http://julianlsolvers.github.io/Optim.jl/stable/ .
1616
"""
1717
module Optim
1818

19+
import ADTypes
20+
1921
using PositiveFactorizations: Positive # for globalization strategy in Newton
2022

2123
using LineSearches: LineSearches # for globalization strategy in Quasi-Newton algs
@@ -35,14 +37,13 @@ using NLSolversBase:
3537
NonDifferentiable,
3638
OnceDifferentiable,
3739
TwiceDifferentiable,
38-
TwiceDifferentiableHV,
3940
AbstractConstraints,
4041
ConstraintBounds,
4142
TwiceDifferentiableConstraints,
4243
nconstraints,
4344
nconstraints_x,
4445
hessian!,
45-
hv_product!
46+
hvp!
4647

4748
# var for NelderMead
4849
using Statistics: var

src/api.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,16 @@ g_norm_trace(r::OptimizationResults) =
9999
error("g_norm_trace is not implemented for $(summary(r)).")
100100
g_norm_trace(r::MultivariateOptimizationResults) = [state.g_norm for state in trace(r)]
101101

102+
# TODO: Overload `NLSolversBase.xxx` instead of defining separate `Optim.xxx` methods?
102103
f_calls(r::OptimizationResults) = r.f_calls
103-
f_calls(d::AbstractObjective) = NLSolversBase.f_calls(d)
104-
105-
g_calls(r::OptimizationResults) = error("g_calls is not implemented for $(summary(r)).")
104+
g_calls(r::OptimizationResults) = error(LazyString("`g_calls` is not implemented for ", summary(r), "."))
106105
g_calls(r::MultivariateOptimizationResults) = r.g_calls
107-
g_calls(d::AbstractObjective) = NLSolversBase.g_calls(d)
108-
109-
h_calls(r::OptimizationResults) = error("h_calls is not implemented for $(summary(r)).")
106+
jvp_calls(r::OptimizationResults) = error(LazyString("`jvp_calls` is not implemented for ", summary(r), "."))
107+
jvp_calls(r::MultivariateOptimizationResults) = r.jvp_calls
108+
h_calls(r::OptimizationResults) = error(LazyString("`h_calls` is not implemented for ", summary(r), "."))
110109
h_calls(r::MultivariateOptimizationResults) = r.h_calls
111-
h_calls(d::AbstractObjective) = NLSolversBase.h_calls(d) + NLSolversBase.hv_calls(d)
110+
hvp_calls(r::OptimizationResults) = error(LazyString("`hvp_calls` is not implemented for ", summary(r), "."))
111+
hvp_calls(r::MultivariateOptimizationResults) = r.hvp_calls
112112

113113
converged(r::UnivariateOptimizationResults) = r.stopped_by.converged
114114
function converged(r::MultivariateOptimizationResults)

src/maximize.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ for api_method in (
109109
:iteration_limit_reached,
110110
:f_calls,
111111
:g_calls,
112+
:jvp_calls,
112113
:h_calls,
114+
:hvp_calls,
113115
)
114116
@eval $api_method(r::MaximizationWrapper) = $api_method(res(r))
115117
end

0 commit comments

Comments
 (0)