Skip to content

Commit 745069d

Browse files
committed
Add support for Jacobian-vector products
1 parent 52f854c commit 745069d

File tree

6 files changed

+481
-24
lines changed

6 files changed

+481
-24
lines changed

src/NLSolversBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ export AbstractObjective,
1414
value!,
1515
value_gradient!,
1616
value_jacobian!,
17+
value_jvp!,
1718
gradient,
1819
gradient!,
1920
jacobian,
2021
jacobian!,
22+
jvp,
23+
jvp!,
2124
hessian,
2225
hessian!,
2326
value!!,
2427
value_gradient!!,
2528
value_jacobian!!,
29+
value_jvp!!,
2630
hessian!!,
2731
hv_product,
2832
hv_product!,
@@ -38,6 +42,7 @@ export AbstractObjective,
3842
clear!,
3943
f_calls,
4044
g_calls,
45+
jvp_calls,
4146
h_calls,
4247
hv_calls
4348

src/interface.jl

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,126 @@ function jacobian(obj::AbstractObjective, x)
153153
return newdf
154154
end
155155

156+
"""
    jvp(obj::OnceDifferentiable)

Return the Jacobian-vector product currently stored in `obj` (the field `obj.JVP`), i.e. the
most recently cached evaluation. Nothing is recomputed.
"""
function jvp(obj::OnceDifferentiable)
    return obj.JVP
end
162+
163+
"""
    jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)

Return the Jacobian-vector product of the objective function `obj` at point `x` with
tangents `v`.

For scalar-valued objectives (`obj.JVP isa Real`) the stored function is called
out-of-place as `obj.jvp(x, v)`. Otherwise it is called in-place as `obj.jvp(out, x, v)`,
the same calling convention that `jvp!!` uses, where `out` is a fresh copy of `obj.JVP`;
`out` is then returned.

!!! note
    This function neither uses cached results nor caches the results: `obj.JVP`,
    `obj.x_jvp` and `obj.v_jvp` are left untouched. Only `obj.jvp_calls` is incremented.
"""
function jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
    obj.jvp_calls += 1
    if obj.JVP isa Real
        return obj.jvp(x, v)
    else
        # Array-valued objectives store an in-place `jvp!(out, x, v)`. Write into a
        # private buffer so that the cached `obj.JVP` is not overwritten.
        out = copy(obj.JVP)
        obj.jvp(out, x, v)
        return out
    end
end
175+
176+
"""
    jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)

Return the Jacobian-vector product of the objective function `obj` at point `x` with
tangents `v`, and cache the result in `obj`.

!!! note
    The cached result is reused when `x` equals `obj.x_jvp` and `v` equals `obj.v_jvp`;
    otherwise the product is recomputed through `jvp!!`.
"""
function jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
    cache_is_valid = x == obj.x_jvp && v == obj.v_jvp
    cache_is_valid || jvp!!(obj, x, v)
    return jvp(obj)
end
191+
192+
"""
    jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)

Evaluate the Jacobian-vector product of the objective function `obj` at point `x` with
tangents `v`, store it in `obj.JVP`, record `x` and `v` in `obj.x_jvp` and `obj.v_jvp`,
increment `obj.jvp_calls`, and return the stored product.

!!! note
    This function ignores any cached result and always re-evaluates the Jacobian-vector
    product.
"""
function jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
    if !(obj.JVP isa Real)
        # Non-scalar cache: the stored function writes into the buffer in-place.
        obj.jvp(obj.JVP, x, v)
    else
        # Scalar cache: the stored function returns the product.
        obj.JVP = obj.jvp(x, v)
    end
    # Remember the inputs this cache entry belongs to.
    copyto!(obj.x_jvp, x)
    copyto!(obj.v_jvp, v)
    obj.jvp_calls += 1
    return obj.JVP
end
212+
213+
"""
    value_jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)

Return the value and the Jacobian-vector product of the objective function `obj` at point
`x` with tangents `v`, as a tuple `(value, jvp)`.

For scalar-valued objectives (`obj.F isa Real`) the stored function is called out-of-place
as `obj.fjvp(x, v)` and its result is returned. Otherwise it is called in-place as
`obj.fjvp(fx, jv, x, v)`, the same calling convention that `value_jvp!!` uses, where `fx`
and `jv` are fresh copies of `obj.F` and `obj.JVP`; `(fx, jv)` is then returned.

!!! note
    This function neither uses cached results nor caches the results: `obj.F`, `obj.JVP`
    and the recorded inputs are left untouched. Only `obj.f_calls` and `obj.jvp_calls`
    are incremented.
"""
function value_jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
    obj.f_calls += 1
    obj.jvp_calls += 1
    if obj.F isa Real
        return obj.fjvp(x, v)
    else
        # Array-valued objectives store an in-place `fjvp!(fx, jvp, x, v)`. Evaluate into
        # private buffers so that the cached `obj.F` and `obj.JVP` are not overwritten.
        fx = copy(obj.F)
        jv = copy(obj.JVP)
        obj.fjvp(fx, jv, x, v)
        return fx, jv
    end
end
226+
227+
"""
    value_jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)

Return the value and the Jacobian-vector product of the objective function `obj` at point
`x` with tangents `v`, and cache the results in `obj`.

!!! note
    Cached results are reused where possible: the value is recomputed only if `x` differs
    from `obj.x_f`, and the Jacobian-vector product only if `x` differs from `obj.x_jvp`
    or `v` differs from `obj.v_jvp`.
"""
function value_jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
    value_is_stale = x != obj.x_f
    jvp_is_stale = x != obj.x_jvp || v != obj.v_jvp
    if value_is_stale && jvp_is_stale
        # Neither cache entry matches: evaluate both in one call.
        value_jvp!!(obj, x, v)
    elseif value_is_stale
        # The product is still valid; refresh the value only.
        value!!(obj, x)
    elseif jvp_is_stale
        # The value is still valid; refresh the product only.
        jvp!!(obj, x, v)
    end
    return obj.F, obj.JVP
end
250+
251+
"""
    value_jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)

Evaluate the value and the Jacobian-vector product of the objective function `obj` at point
`x` with tangents `v`, store them in `obj.F` and `obj.JVP`, record the inputs in `obj.x_f`,
`obj.x_jvp` and `obj.v_jvp`, increment `obj.f_calls` and `obj.jvp_calls`, and return
`(obj.F, obj.JVP)`.

!!! note
    This function ignores any cached results and always re-evaluates both the value and the
    Jacobian-vector product.
"""
function value_jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
    if !(obj.F isa Real)
        # Non-scalar value: the stored function fills both cache buffers in-place.
        obj.fjvp(obj.F, obj.JVP, x, v)
    else
        # Scalar value: the stored function returns the pair (value, product).
        fx, jv = obj.fjvp(x, v)
        obj.F = fx
        obj.JVP = jv
    end
    # Remember the inputs these cache entries belong to.
    copyto!(obj.x_f, x)
    copyto!(obj.x_jvp, x)
    copyto!(obj.v_jvp, v)
    obj.f_calls += 1
    obj.jvp_calls += 1
    return obj.F, obj.JVP
end
275+
156276
# Two-argument `value` for objectives whose cached value `obj.F` is an array: delegate to
# the three-argument method `value(obj, F, x)`, passing a copy of `obj.F` rather than the
# cached buffer itself.
function value(obj::NonDifferentiable{TF, TX}, x) where {TF<:AbstractArray, TX}
    return value(obj, copy(obj.F), x)
end
function value(obj::OnceDifferentiable{TF, TDF, TX}, x) where {TF<:AbstractArray, TDF, TX}
    return value(obj, copy(obj.F), x)
end
158278
function value(obj::AbstractObjective, F, x)
@@ -188,6 +308,18 @@ function _clear_df!(d::AbstractObjective)
188308
nothing
189309
end
190310

311+
# Reset the Jacobian-vector-product state of `d`: zero the call counter and overwrite the
# cached product and the recorded inputs (`x_jvp`, `v_jvp`) with NaN.
function _clear_jvp!(d::AbstractObjective)
    d.jvp_calls = 0
    if !(d.JVP isa AbstractArray)
        # Non-array cache: replace the stored value.
        d.JVP = NaN
    else
        # Array cache: overwrite the buffer in-place.
        fill!(d.JVP, NaN)
    end
    fill!(d.x_jvp, NaN)
    fill!(d.v_jvp, NaN)
    return nothing
end
322+
191323
function _clear_h!(d::AbstractObjective)
192324
d.h_calls = 0
193325
fill!(d.H, NaN)
@@ -208,6 +340,7 @@ clear!(d::NonDifferentiable) = _clear_f!(d)
208340
# Reset all cached state of a once-differentiable objective by running, in order, the
# value, derivative and Jacobian-vector-product reset helpers. Returns `nothing`.
function clear!(d::OnceDifferentiable)
    for reset! in (_clear_f!, _clear_df!, _clear_jvp!)
        reset!(d)
    end
    return nothing
end
213346

@@ -229,6 +362,8 @@ g_calls(d::NonDifferentiable) = 0
229362
# Call-counter accessors.

# Value and derivative evaluation counts are read from the objective's fields.
f_calls(obj) = obj.f_calls
g_calls(obj) = obj.df_calls

# Jacobian-vector-product count: read from the field for `OnceDifferentiable`,
# zero for everything else.
jvp_calls(::Any) = 0
jvp_calls(obj::OnceDifferentiable) = obj.jvp_calls

# Hessian count: read from the field by default, zero for the listed objective types.
h_calls(obj) = obj.h_calls
h_calls(::Union{NonDifferentiable, OnceDifferentiable}) = 0
h_calls(::TwiceDifferentiableHV) = 0

# Hessian-vector-product count: generic fallback.
hv_calls(::Any) = 0

src/objective_types/abstract.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,45 @@ function make_fdf(x, F::Number, f, g!)
1515
end
1616
end
1717

18+
# Build a Jacobian-vector-product function from a derivative function.
#
# Array-valued case: `j!(J, x)` fills the Jacobian buffer `J` in-place. The returned closure
# `jvp!(out, x, v)` evaluates the Jacobian into a buffer allocated once here, stores
# `J * v` in `out`, and returns `out`.
function make_jvp(x::AbstractArray, F::AbstractArray, j!)
    jac_buffer = alloc_DF(x, F)
    function jvp!(out, point, tangent)
        j!(jac_buffer, point)
        LinearAlgebra.mul!(out, jac_buffer, tangent)
        return out
    end
    return jvp!
end
28+
# Scalar-valued case: `g!(G, x)` fills the gradient buffer `G` in-place. The returned closure
# `jvp(x, v)` evaluates the gradient into a buffer allocated once here and returns
# `dot(G, v)`.
function make_jvp(x::AbstractArray, F::Number, g!)
    grad_buffer = alloc_DF(x, F)
    function jvp(point, tangent)
        g!(grad_buffer, point)
        return LinearAlgebra.dot(grad_buffer, tangent)
    end
    return jvp
end
36+
37+
# Build a function that evaluates the objective value together with the
# Jacobian-vector product from a combined value-and-derivative function.
#
# Array-valued case: `fj!(fx, J, x)` fills the value buffer `fx` and the Jacobian buffer `J`
# in-place. The returned closure `fjvp!(fx, out, x, v)` evaluates both, using a Jacobian
# buffer allocated once here, stores `J * v` in `out`, and returns `(fx, out)`.
function make_fjvp(x::AbstractArray, F::AbstractArray, fj!)
    jac_buffer = alloc_DF(x, F)
    function fjvp!(fx, out, point, tangent)
        fj!(fx, jac_buffer, point)
        LinearAlgebra.mul!(out, jac_buffer, tangent)
        return fx, out
    end
    return fjvp!
end
47+
# Scalar-valued case: `fg!(G, x)` fills the gradient buffer `G` in-place and returns the
# objective value. The returned closure `fjvp(x, v)` evaluates both, using a gradient
# buffer allocated once here, and returns `(value, dot(G, v))`.
function make_fjvp(x::AbstractArray, F::Number, fg!)
    grad_buffer = alloc_DF(x, F)
    function fjvp(point, tangent)
        fval = fg!(grad_buffer, point)
        return fval, LinearAlgebra.dot(grad_buffer, tangent)
    end
    return fjvp
end
55+
56+
1857
# Initialize a length(F)-by-length(x) Jacobian buffer as the outer product of `vec(F)` and
# `vec(x)` scaled by a NaN of `eltype(x)`; for real floating-point inputs every entry is NaN.
function alloc_DF(x, F)
    nan_of_x = eltype(x)(NaN)
    return nan_of_x .* vec(F) .* vec(x)'
end
2059

0 commit comments

Comments
 (0)