Skip to content

Commit 354fb0e

Browse files
committed
Add support for Jacobian-vector products
1 parent 52f854c commit 354fb0e

File tree

11 files changed

+602
-136
lines changed

11 files changed

+602
-136
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ NLSolversBase.jl is the core, common dependency of several packages in the [Juli
1919
The package aims at establishing common ground for [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl), and [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl). The common ground is mainly the types used to hold objective related callables, information about the objectives, and an interface to interact with these types.
2020

2121
## NDifferentiable
22-
There are currently three main types: `NonDifferentiable`, `OnceDifferentiable`, and `TwiceDifferentiable`. There's also a more experimental `TwiceDifferentiableHV` for optimization algorithms that use Hessian-vector products. An `NDifferentiable` instance can be used to hold relevant functions for
22+
There are currently three main types: `NonDifferentiable`, `OnceDifferentiable`, and `TwiceDifferentiable`. An `NDifferentiable` instance can be used to hold relevant functions for
2323

2424
- Optimization: $F : \mathbb{R}^N \to \mathbb{R}$
2525
- Solving systems of equations: $F : \mathbb{R}^N \to \mathbb{R}^N$

src/NLSolversBase.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@ export AbstractObjective,
99
NonDifferentiable,
1010
OnceDifferentiable,
1111
TwiceDifferentiable,
12-
TwiceDifferentiableHV,
1312
value,
1413
value!,
1514
value_gradient!,
1615
value_jacobian!,
16+
value_jvp!,
1717
gradient,
1818
gradient!,
1919
jacobian,
2020
jacobian!,
21+
jvp,
22+
jvp!,
2123
hessian,
2224
hessian!,
2325
value!!,
@@ -38,6 +40,7 @@ export AbstractObjective,
3840
clear!,
3941
f_calls,
4042
g_calls,
43+
jvp_calls,
4144
h_calls,
4245
hv_calls
4346

@@ -51,12 +54,11 @@ include("objective_types/abstract.jl")
5154
include("objective_types/nondifferentiable.jl")
5255
include("objective_types/oncedifferentiable.jl")
5356
include("objective_types/twicedifferentiable.jl")
54-
include("objective_types/twicedifferentiablehv.jl")
5557
include("objective_types/incomplete.jl")
5658
include("objective_types/constraints.jl")
5759
include("interface.jl")
5860

5961
NonDifferentiable(f::OnceDifferentiable, x::AbstractArray) = NonDifferentiable(f.f, x, copy(f.F))
6062
NonDifferentiable(f::TwiceDifferentiable, x::AbstractArray) = NonDifferentiable(f.f, x, copy(f.F))
61-
NonDifferentiable(f::TwiceDifferentiableHV, x::AbstractArray) = NonDifferentiable(f.f, x, copy(f.F))
63+
6264
end # module

src/interface.jl

Lines changed: 152 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,36 @@ value(obj::AbstractObjective) = obj.F
103103
gradient(obj::AbstractObjective) = obj.DF
104104
"Get the most recently evaluated Jacobian of `obj`."
105105
jacobian(obj::AbstractObjective) = obj.DF
106+
"""
107+
jvp(obj::Union{OnceDifferentiable,TwiceDifferentiable})
108+
109+
Return the most recently evaluated Jacobian-vector product of `obj`.
110+
111+
!!! warn
112+
Generally, it is unsafe to rely on the state of `obj`.
113+
This function should only be used to, e.g., check whether the
114+
most recently evaluated Jacobian-vector product is finite.
115+
In most cases one should use [`jvp!(obj, x, v)`](@ref)
116+
instead of this function.
117+
"""
118+
jvp(obj::Union{OnceDifferentiable, TwiceDifferentiable}) = obj.JVP
106119
"Get the `i`th element of the most recently evaluated gradient of `obj`."
107120
gradient(obj::AbstractObjective, i::Integer) = obj.DF[i]
108121
"Get the most recently evaluated Hessian of `obj`"
109122
hessian(obj::AbstractObjective) = obj.H
123+
"""
124+
hv_product(obj::TwiceDifferentiable)
125+
126+
Return the most recently evaluated Hessian-vector product of `obj`.
127+
128+
!!! warn
129+
Generally, it is unsafe to rely on the state of `obj`.
130+
This function should only be used to, e.g., check whether the
131+
most recently evaluated Hessian-vector product is finite.
132+
In most cases one should use [`hv_product!(obj, x, v)`](@ref)
133+
instead of this function.
134+
"""
135+
hv_product(obj::TwiceDifferentiable) = obj.Hv
110136

111137
value_jacobian!(obj, x) = value_jacobian!(obj, obj.F, obj.DF, x)
112138
function value_jacobian!(obj, F, J, x)
@@ -153,6 +179,92 @@ function jacobian(obj::AbstractObjective, x)
153179
return newdf
154180
end
155181

182+
"""
183+
jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
184+
185+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
186+
and cache the results in `obj`.
187+
188+
!!! note
189+
This function does use cached results if available.
190+
"""
191+
function jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
192+
if x != obj.x_jvp || v != obj.v_jvp
193+
jvp!!(obj, x, v)
194+
end
195+
return obj.JVP
196+
end
197+
198+
"""
199+
jvp!!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
200+
201+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
202+
and cache the results in `obj`.
203+
204+
!!! note
205+
This function does not use cached results but forces reevaluation of the Jacobian-vector product.
206+
"""
207+
function jvp!!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
208+
if obj.JVP isa Real
209+
obj.JVP = obj.jvp(x, v)
210+
else
211+
obj.jvp(obj.JVP, x, v)
212+
end
213+
copyto!(obj.x_jvp, x)
214+
copyto!(obj.v_jvp, v)
215+
obj.jvp_calls += 1
216+
return obj.JVP
217+
end
218+
219+
"""
220+
value_jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
221+
222+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
223+
and cache the results in `obj`.
224+
225+
!!! note
226+
This function does use cached results if available.
227+
"""
228+
function value_jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
229+
if x != obj.x_f
230+
if x != obj.x_jvp || v != obj.v_jvp
231+
# Both value and Jacobian-vector product have to be recomputed
232+
value_jvp!!(obj, x, v)
233+
else
234+
# Only value has to be recomputed
235+
value!!(obj, x)
236+
end
237+
elseif x != obj.x_jvp || v != obj.v_jvp
238+
jvp!!(obj, x, v)
239+
end
240+
return obj.F, obj.JVP
241+
end
242+
243+
"""
244+
value_jvp!!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
245+
246+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
247+
and cache the results in `obj`.
248+
249+
!!! note
250+
This function does not use cached results but forces reevaluation of the Jacobian-vector product.
251+
"""
252+
function value_jvp!!(obj::Union{OnceDifferentiable,TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
253+
if obj.F isa Real
254+
y, ty = obj.fjvp(x, v)
255+
obj.F = y
256+
obj.JVP = ty
257+
else
258+
obj.fjvp(obj.F, obj.JVP, x, v)
259+
end
260+
copyto!(obj.x_f, x)
261+
copyto!(obj.x_jvp, x)
262+
copyto!(obj.v_jvp, v)
263+
obj.f_calls += 1
264+
obj.jvp_calls += 1
265+
return obj.F, obj.JVP
266+
end
267+
156268
value(obj::NonDifferentiable{TF, TX}, x) where {TF<:AbstractArray, TX} = value(obj, copy(obj.F), x)
157269
value(obj::OnceDifferentiable{TF, TDF, TX}, x) where {TF<:AbstractArray, TDF, TX} = value(obj, copy(obj.F), x)
158270
function value(obj::AbstractObjective, F, x)
@@ -170,6 +282,25 @@ function value!!(obj::AbstractObjective, F, x)
170282
F
171283
end
172284

285+
function hv_product!(obj::TwiceDifferentiable, x, v)
286+
if x != obj.x_hv || v != obj.v_hv
287+
hv_product!!(obj, x, v)
288+
end
289+
obj.Hv
290+
end
291+
function hv_product!!(obj::TwiceDifferentiable, x, v)
292+
if obj.hv === nothing
293+
H = hessian!(obj, x)
294+
LinearAlgebra.mul!(obj.Hv, H, v)
295+
else
296+
obj.hv_calls += 1
297+
obj.hv(obj.Hv, x, v)
298+
end
299+
copyto!(obj.x_hv, x)
300+
copyto!(obj.v_hv, v)
301+
return obj.Hv
302+
end
303+
173304
function _clear_f!(d::AbstractObjective)
174305
d.f_calls = 0
175306
if d.F isa AbstractArray
@@ -188,6 +319,18 @@ function _clear_df!(d::AbstractObjective)
188319
nothing
189320
end
190321

322+
function _clear_jvp!(d::AbstractObjective)
323+
d.jvp_calls = 0
324+
if d.JVP isa AbstractArray
325+
fill!(d.JVP, NaN)
326+
else
327+
d.JVP = NaN
328+
end
329+
fill!(d.x_jvp, NaN)
330+
fill!(d.v_jvp, NaN)
331+
nothing
332+
end
333+
191334
function _clear_h!(d::AbstractObjective)
192335
d.h_calls = 0
193336
fill!(d.H, NaN)
@@ -208,28 +351,25 @@ clear!(d::NonDifferentiable) = _clear_f!(d)
208351
function clear!(d::OnceDifferentiable)
209352
_clear_f!(d)
210353
_clear_df!(d)
354+
_clear_jvp!(d)
211355
nothing
212356
end
213357

214358
function clear!(d::TwiceDifferentiable)
215359
_clear_f!(d)
216360
_clear_df!(d)
361+
_clear_jvp!(d)
217362
_clear_h!(d)
218-
nothing
219-
end
220-
221-
function clear!(d::TwiceDifferentiableHV)
222-
_clear_f!(d)
223-
_clear_df!(d)
224363
_clear_hv!(d)
225364
nothing
226365
end
227366

367+
f_calls(d::Union{NonDifferentiable, OnceDifferentiable, TwiceDifferentiable}) = d.f_calls
228368
g_calls(d::NonDifferentiable) = 0
369+
g_calls(d::Union{OnceDifferentiable, TwiceDifferentiable}) = d.df_calls
370+
jvp_calls(d::NonDifferentiable) = 0
371+
jvp_calls(d::Union{OnceDifferentiable, TwiceDifferentiable}) = d.jvp_calls
229372
h_calls(d::Union{NonDifferentiable, OnceDifferentiable}) = 0
230-
f_calls(d) = d.f_calls
231-
g_calls(d) = d.df_calls
232-
h_calls(d) = d.h_calls
233-
hv_calls(d) = 0
234-
h_calls(d::TwiceDifferentiableHV) = 0
235-
hv_calls(d::TwiceDifferentiableHV) = d.hv_calls
373+
h_calls(d::TwiceDifferentiable) = d.h_calls
374+
hv_calls(d::Union{NonDifferentiable, OnceDifferentiable}) = 0
375+
hv_calls(d::TwiceDifferentiable) = d.hv_calls

src/objective_types/abstract.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,45 @@ function make_fdf(x, F::Number, f, g!)
1515
end
1616
end
1717

18+
# Given Julia functions of the gradient/Jacobian, create a function that calculates the Jacobian-vector product.
19+
function make_jvp(x::AbstractArray, F::AbstractArray, j!)
20+
let j! = j!, jx = alloc_DF(x, F)
21+
function jvp!(jvp, x, v)
22+
j!(jx, x)
23+
LinearAlgebra.mul!(jvp, jx, v)
24+
return jvp
25+
end
26+
end
27+
end
28+
function make_jvp(x::AbstractArray, F::Number, g!)
29+
let g! = g!, gx = alloc_DF(x, F)
30+
function jvp(x, v)
31+
g!(gx, x)
32+
return LinearAlgebra.dot(gx, v)
33+
end
34+
end
35+
end
36+
37+
# Given a Julia function of the objective function and its gradient/Jacobian, create a function that calculates the function value and the Jacobian-vector product.
38+
function make_fjvp(x::AbstractArray, F::AbstractArray, fj!)
39+
let fj! = fj!, jx = alloc_DF(x, F)
40+
function fjvp!(fx, jvp, x, v)
41+
fj!(fx, jx, x)
42+
LinearAlgebra.mul!(jvp, jx, v)
43+
return fx, jvp
44+
end
45+
end
46+
end
47+
function make_fjvp(x::AbstractArray, F::Number, fg!)
48+
let fg! = fg!, gx = alloc_DF(x, F)
49+
function jvp(x, v)
50+
fx = fg!(gx, x)
51+
return fx, LinearAlgebra.dot(gx, v)
52+
end
53+
end
54+
end
55+
56+
1857
# Initialize an n-by-n Jacobian
1958
alloc_DF(x, F) = eltype(x)(NaN) .* vec(F) .* vec(x)'
2059

0 commit comments

Comments
 (0)