Skip to content

Commit 1b405fc

Browse files
committed
Add support for Jacobian-vector products
1 parent 9d4e9c7 commit 1b405fc

File tree

4 files changed

+233
-28
lines changed

4 files changed

+233
-28
lines changed

src/NLSolversBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@ export AbstractObjective,
1515
value!,
1616
value_gradient!,
1717
value_jacobian!,
18+
value_jvp!,
1819
gradient,
1920
gradient!,
2021
jacobian,
2122
jacobian!,
23+
jvp,
24+
jvp!,
2225
hessian,
2326
hessian!,
2427
value!!,
2528
value_gradient!!,
2629
value_jacobian!!,
30+
value_jvp!!,
2731
hessian!!,
2832
hv_product,
2933
hv_product!,
@@ -39,6 +43,7 @@ export AbstractObjective,
3943
clear!,
4044
f_calls,
4145
g_calls,
46+
jvp_calls,
4247
h_calls,
4348
hv_calls
4449

src/interface.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,127 @@ function jacobian(obj::AbstractObjective, x)
153153
return newdf
154154
end
155155

156+
"""
157+
jvp(obj::OnceDifferentiable)
158+
159+
Return the most recently evaluated Jacobian-vector product of the objective function `obj`.
160+
"""
161+
jvp(obj::OnceDifferentiable) = obj.JVP
162+
163+
"""
164+
jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
165+
166+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`.
167+
168+
!!! note
169+
This function does neither use cached results nor cache the results.
170+
"""
171+
function jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
172+
obj.jvp_calls += 1
173+
return obj.jvp(x, v)
174+
end
175+
176+
"""
177+
jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
178+
179+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
180+
and cache the results in `obj`.
181+
182+
!!! note
183+
This function does use cached results if available.
184+
"""
185+
function jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
186+
if x != obj.x_jvp || v != obj.v_jvp
187+
jvp!!(obj, x, v)
188+
end
189+
return jvp(obj)
190+
end
191+
192+
"""
193+
jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
194+
195+
Return the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
196+
and cache the results in `obj`.
197+
198+
!!! note
199+
This function does not use cached results but forces reevaluation of the Jacobian-vector product.
200+
"""
201+
function jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
202+
if obj.JVP isa Real
203+
z = obj.jvp(x, v)
204+
obj.JVP = z
205+
else
206+
obj.jvp(obj.JVP, x, v)
207+
end
208+
copyto!(obj.x_jvp, x)
209+
copyto!(obj.v_jvp, v)
210+
obj.jvp_calls += 1
211+
return z
212+
end
213+
214+
"""
215+
value_jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
216+
217+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`.
218+
219+
!!! note
220+
This function does neither use cached results nor cache the results.
221+
"""
222+
function value_jvp(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
223+
obj.f_calls += 1
224+
obj.jvp_calls += 1
225+
return obj.fjvp(x, v)
226+
end
227+
228+
"""
229+
value_jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
230+
231+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
232+
and cache the results in `obj`.
233+
234+
!!! note
235+
This function does use cached results if available.
236+
"""
237+
function value_jvp!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
238+
if x != obj.x_f
239+
if x != obj.x_jvp || v != obj.v_jvp
240+
# Both value and Jacobian-vector product have to be recomputed
241+
value_jvp!!(obj, x, v)
242+
else
243+
# Only value has to be recomputed
244+
value!!(obj, x)
245+
end
246+
elseif x != obj.x_jvp || v != obj.v_jvp
247+
jvp!!(obj, x, v)
248+
end
249+
return obj.F, obj.JVP
250+
end
251+
252+
"""
253+
value_jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
254+
255+
Return the value and the Jacobian-vector product of the objective function `obj` at point `x` with tangents `v`,
256+
and cache the results in `obj`.
257+
258+
!!! note
259+
This function does not use cached results but forces reevaluation of the Jacobian-vector product.
260+
"""
261+
function value_jvp!!(obj::OnceDifferentiable, x::AbstractArray, v::AbstractArray)
262+
if obj.F isa Real
263+
y, ty = obj.fjvp(x, v)
264+
obj.F = y
265+
obj.JVP = ty
266+
else
267+
obj.fjvp(obj.F, obj.JVP, x, v)
268+
end
269+
copyto!(obj.x_f, x)
270+
copyto!(obj.x_jvp, x)
271+
copyto!(obj.v_jvp, v)
272+
obj.f_calls += 1
273+
obj.jvp_calls += 1
274+
return obj.F, obj.JVP
275+
end
276+
156277
value(obj::NonDifferentiable{TF, TX}, x) where {TF<:AbstractArray, TX} = value(obj, copy(obj.F), x)
157278
value(obj::OnceDifferentiable{TF, TDF, TX}, x) where {TF<:AbstractArray, TDF, TX} = value(obj, copy(obj.F), x)
158279
function value(obj::AbstractObjective, F, x)
@@ -188,6 +309,14 @@ function _clear_df!(d::AbstractObjective)
188309
nothing
189310
end
190311

312+
function _clear_jvp!(d::AbstractObjective)
313+
d.jvp_calls = 0
314+
fill!(d.JVP, NaN)
315+
fill!(d.x_jvp, NaN)
316+
fill!(d.v_jvp, NaN)
317+
nothing
318+
end
319+
191320
function _clear_h!(d::AbstractObjective)
192321
d.h_calls = 0
193322
fill!(d.H, NaN)
@@ -208,6 +337,7 @@ clear!(d::NonDifferentiable) = _clear_f!(d)
208337
function clear!(d::OnceDifferentiable)
209338
_clear_f!(d)
210339
_clear_df!(d)
340+
_clear_jvp!(d)
211341
nothing
212342
end
213343

@@ -229,6 +359,8 @@ g_calls(d::NonDifferentiable) = 0
229359
h_calls(d::Union{NonDifferentiable, OnceDifferentiable}) = 0
230360
f_calls(d) = d.f_calls
231361
g_calls(d) = d.df_calls
362+
jvp_calls(d) = 0
363+
jvp_calls(d::OnceDifferentiable) = d.jvp_calls
232364
h_calls(d) = d.h_calls
233365
hv_calls(d) = 0
234366
h_calls(d::TwiceDifferentiableHV) = 0

src/objective_types/inplace_factory.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ function fdf!_from_fdf(fj, F::AbstractArray, inplace)
5555
end
5656
end
5757
end
58+
59+
function jvp!_from_jvp(jvp, F::AbstractArray, inplace::Bool)
60+
if inplace
61+
return jvp
62+
else
63+
return function jvp!(JVP, x, v)
64+
copyto!(JVP, jvp(x, v))
65+
end
66+
end
67+
end
68+
function fjvp!_from_fjvp(fjvp, F::AbstractArray, inplace::Bool)
69+
if inplace
70+
return fjvp
71+
else
72+
return function fjvp!(F, JVP, x, v)
73+
fx, jvpxv = fjvp(x, v)
74+
copyto!(F, fx)
75+
copyto!(JVP, jvpxv)
76+
return fx, jvpxv
77+
end
78+
end
79+
end
80+
5881
function h!_from_h(h, F::Real, inplace)
5982
if inplace
6083
return h

0 commit comments

Comments
 (0)