Skip to content

Commit 23812df

Browse files
committed
Add support for Jacobian-vector products
1 parent 52f854c commit 23812df

File tree

11 files changed

+535
-126
lines changed

11 files changed

+535
-126
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ NLSolversBase.jl is the core, common dependency of several packages in the [Juli
1919
The package aims at establishing common ground for [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl), and [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl). The common ground is mainly the types used to hold objective related callables, information about the objectives, and an interface to interact with these types.
2020

2121
## NDifferentiable
22-
There are currently three main types: `NonDifferentiable`, `OnceDifferentiable`, and `TwiceDifferentiable`. There's also a more experimental `TwiceDifferentiableHV` for optimization algorithms that use Hessian-vector products. An `NDifferentiable` instance can be used to hold relevant functions for
22+
There are currently three main types: `NonDifferentiable`, `OnceDifferentiable`, and `TwiceDifferentiable`. An `NDifferentiable` instance can be used to hold relevant functions for
2323

2424
- Optimization: $F : \mathbb{R}^N \to \mathbb{R}$
2525
- Solving systems of equations: $F : \mathbb{R}^N \to \mathbb{R}^N$

src/NLSolversBase.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,24 @@ export AbstractObjective,
99
NonDifferentiable,
1010
OnceDifferentiable,
1111
TwiceDifferentiable,
12-
TwiceDifferentiableHV,
1312
value,
1413
value!,
1514
value_gradient!,
1615
value_jacobian!,
16+
value_jvp!,
1717
gradient,
1818
gradient!,
1919
jacobian,
2020
jacobian!,
21+
jvp,
22+
jvp!,
2123
hessian,
2224
hessian!,
2325
value!!,
2426
value_gradient!!,
2527
value_jacobian!!,
28+
value_jvp!!,
2629
hessian!!,
27-
hv_product,
2830
hv_product!,
2931
only_fg!,
3032
only_fgh!,
@@ -38,6 +40,7 @@ export AbstractObjective,
3840
clear!,
3941
f_calls,
4042
g_calls,
43+
jvp_calls,
4144
h_calls,
4245
hv_calls
4346

@@ -51,12 +54,11 @@ include("objective_types/abstract.jl")
5154
include("objective_types/nondifferentiable.jl")
5255
include("objective_types/oncedifferentiable.jl")
5356
include("objective_types/twicedifferentiable.jl")
54-
include("objective_types/twicedifferentiablehv.jl")
5557
include("objective_types/incomplete.jl")
5658
include("objective_types/constraints.jl")
5759
include("interface.jl")
5860

5961
NonDifferentiable(f::OnceDifferentiable, x::AbstractArray) = NonDifferentiable(f.f, x, copy(f.F))
6062
NonDifferentiable(f::TwiceDifferentiable, x::AbstractArray) = NonDifferentiable(f.f, x, copy(f.F))
61-
NonDifferentiable(f::TwiceDifferentiableHV, x::AbstractArray) = NonDifferentiable(f.f, x, copy(f.F))
63+
6264
end # module

src/interface.jl

Lines changed: 120 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,92 @@ function jacobian(obj::AbstractObjective, x)
153153
return newdf
154154
end
155155

156+
"""
157+
jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
158+
159+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
160+
and cache the results in `obj`.
161+
162+
!!! note
163+
This function does use cached results if available.
164+
"""
165+
function jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
166+
if x != obj.x_jvp || v != obj.v_jvp
167+
jvp!!(obj, x, v)
168+
end
169+
return obj.JVP
170+
end
171+
172+
"""
173+
jvp!!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
174+
175+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
176+
and cache the results in `obj`.
177+
178+
!!! note
179+
This function does not use cached results but forces reevaluation of the Jacobian-vector product.
180+
"""
181+
function jvp!!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
182+
if obj.JVP isa Real
183+
obj.JVP = obj.jvp(x, v)
184+
else
185+
obj.jvp(obj.JVP, x, v)
186+
end
187+
copyto!(obj.x_jvp, x)
188+
copyto!(obj.v_jvp, v)
189+
obj.jvp_calls += 1
190+
return obj.JVP
191+
end
192+
193+
"""
194+
value_jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
195+
196+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
197+
and cache the results in `obj`.
198+
199+
!!! note
200+
This function does use cached results if available.
201+
"""
202+
function value_jvp!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
203+
if x != obj.x_f
204+
if x != obj.x_jvp || v != obj.v_jvp
205+
# Both value and Jacobian-vector product have to be recomputed
206+
value_jvp!!(obj, x, v)
207+
else
208+
# Only value has to be recomputed
209+
value!!(obj, x)
210+
end
211+
elseif x != obj.x_jvp || v != obj.v_jvp
212+
jvp!!(obj, x, v)
213+
end
214+
return obj.F, obj.JVP
215+
end
216+
217+
"""
218+
value_jvp!!(obj::Union{OnceDifferentiable, TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
219+
220+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
221+
and cache the results in `obj`.
222+
223+
!!! note
224+
This function does not use cached results but forces reevaluation of the Jacobian-vector product.
225+
"""
226+
function value_jvp!!(obj::Union{OnceDifferentiable,TwiceDifferentiable}, x::AbstractArray, v::AbstractArray)
227+
if obj.F isa Real
228+
y, ty = obj.fjvp(x, v)
229+
obj.F = y
230+
obj.JVP = ty
231+
else
232+
obj.fjvp(obj.F, obj.JVP, x, v)
233+
end
234+
copyto!(obj.x_f, x)
235+
copyto!(obj.x_jvp, x)
236+
copyto!(obj.v_jvp, v)
237+
obj.f_calls += 1
238+
obj.jvp_calls += 1
239+
return obj.F, obj.JVP
240+
end
241+
156242
value(obj::NonDifferentiable{TF, TX}, x) where {TF<:AbstractArray, TX} = value(obj, copy(obj.F), x)
157243
value(obj::OnceDifferentiable{TF, TDF, TX}, x) where {TF<:AbstractArray, TDF, TX} = value(obj, copy(obj.F), x)
158244
function value(obj::AbstractObjective, F, x)
@@ -170,6 +256,19 @@ function value!!(obj::AbstractObjective, F, x)
170256
F
171257
end
172258

259+
function hv_product!(obj::TwiceDifferentiable, x, v)
260+
if x != obj.x_hv || v != obj.v_hv
261+
hv_product!!(obj, x, v)
262+
end
263+
obj.Hv
264+
end
265+
function hv_product!!(obj::TwiceDifferentiable, x, v)
266+
obj.hv_calls += 1
267+
copyto!(obj.x_hv, x)
268+
copyto!(obj.v_hv, v)
269+
obj.hv(obj.Hv, x, v)
270+
end
271+
173272
function _clear_f!(d::AbstractObjective)
174273
d.f_calls = 0
175274
if d.F isa AbstractArray
@@ -188,6 +287,18 @@ function _clear_df!(d::AbstractObjective)
188287
nothing
189288
end
190289

290+
function _clear_jvp!(d::AbstractObjective)
291+
d.jvp_calls = 0
292+
if d.JVP isa AbstractArray
293+
fill!(d.JVP, NaN)
294+
else
295+
d.JVP = NaN
296+
end
297+
fill!(d.x_jvp, NaN)
298+
fill!(d.v_jvp, NaN)
299+
nothing
300+
end
301+
191302
function _clear_h!(d::AbstractObjective)
192303
d.h_calls = 0
193304
fill!(d.H, NaN)
@@ -208,28 +319,25 @@ clear!(d::NonDifferentiable) = _clear_f!(d)
208319
function clear!(d::OnceDifferentiable)
209320
_clear_f!(d)
210321
_clear_df!(d)
322+
_clear_jvp!(d)
211323
nothing
212324
end
213325

214326
function clear!(d::TwiceDifferentiable)
215327
_clear_f!(d)
216328
_clear_df!(d)
329+
_clear_jvp!(d)
217330
_clear_h!(d)
218-
nothing
219-
end
220-
221-
function clear!(d::TwiceDifferentiableHV)
222-
_clear_f!(d)
223-
_clear_df!(d)
224331
_clear_hv!(d)
225332
nothing
226333
end
227334

335+
f_calls(d::Union{NonDifferentiable, OnceDifferentiable, TwiceDifferentiable}) = d.f_calls
228336
g_calls(d::NonDifferentiable) = 0
337+
g_calls(d::Union{OnceDifferentiable, TwiceDifferentiable}) = d.df_calls
338+
jvp_calls(d::NonDifferentiable) = 0
339+
jvp_calls(d::Union{OnceDifferentiable, TwiceDifferentiable}) = d.jvp_calls
229340
h_calls(d::Union{NonDifferentiable, OnceDifferentiable}) = 0
230-
f_calls(d) = d.f_calls
231-
g_calls(d) = d.df_calls
232-
h_calls(d) = d.h_calls
233-
hv_calls(d) = 0
234-
h_calls(d::TwiceDifferentiableHV) = 0
235-
hv_calls(d::TwiceDifferentiableHV) = d.hv_calls
341+
h_calls(d::TwiceDifferentiable) = d.h_calls
342+
hv_calls(d::Union{NonDifferentiable, OnceDifferentiable}) = 0
343+
hv_calls(d::TwiceDifferentiable) = d.hv_calls

src/objective_types/abstract.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,45 @@ function make_fdf(x, F::Number, f, g!)
1515
end
1616
end
1717

18+
# Given Julia functions of the gradient/Jacobian, create a function that calculates the Jacobian-vector product.
19+
function make_jvp(x::AbstractArray, F::AbstractArray, j!)
20+
let j! = j!, jx = alloc_DF(x, F)
21+
function jvp!(jvp, x, v)
22+
j!(jx, x)
23+
LinearAlgebra.mul!(jvp, jx, v)
24+
return jvp
25+
end
26+
end
27+
end
28+
function make_jvp(x::AbstractArray, F::Number, g!)
29+
let g! = g!, gx = alloc_DF(x, F)
30+
function jvp(x, v)
31+
g!(gx, x)
32+
return LinearAlgebra.dot(gx, v)
33+
end
34+
end
35+
end
36+
37+
# Given a Julia function of the objective function and its gradient/Jacobian, create a function that calculates the function value and the Jacobian-vector product.
38+
function make_fjvp(x::AbstractArray, F::AbstractArray, fj!)
39+
let fj! = fj!, jx = alloc_DF(x, F)
40+
function fjvp!(fx, jvp, x, v)
41+
fj!(fx, jx, x)
42+
LinearAlgebra.mul!(jvp, jx, v)
43+
return fx, jvp
44+
end
45+
end
46+
end
47+
function make_fjvp(x::AbstractArray, F::Number, fg!)
48+
let fg! = fg!, gx = alloc_DF(x, F)
49+
function jvp(x, v)
50+
fx = fg!(gx, x)
51+
return fx, LinearAlgebra.dot(gx, v)
52+
end
53+
end
54+
end
55+
56+
1857
# Initialize an n-by-n Jacobian
1958
alloc_DF(x, F) = eltype(x)(NaN) .* vec(F) .* vec(x)'
2059

0 commit comments

Comments
 (0)