Skip to content

Commit aa5b954

Browse files
author
oscarddssmith
committed
actually add cache
1 parent b0922f6 commit aa5b954

File tree

1 file changed

+144
-45
lines changed

1 file changed

+144
-45
lines changed

src/jvp.jl

Lines changed: 144 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,197 @@
1+
mutable struct JVPCache{X1, FX1, FDType}
2+
x1 :: X1
3+
fx1 :: FX1
4+
end
5+
6+
"""
7+
FiniteDiff.JVPCache(
8+
x,
9+
fdtype :: Type{T1} = Val{:forward})
10+
11+
Allocating Cache Constructor.
12+
"""
13+
function JVPCache(
14+
x,
15+
fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD}
16+
fdtype isa Type && (fdtype = fdtype())
17+
JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x))
18+
end
19+
20+
"""
21+
FiniteDiff.JVPCache(
22+
x,
23+
fx1,
24+
fdtype :: Type{T1} = Val{:forward},
25+
26+
Non-Allocating Cache Constructor.
27+
"""
28+
function JVPCache(
29+
x,
30+
fx,
31+
fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD}
32+
fdtype isa Type && (fdtype = fdtype())
33+
JVPCache{typeof(x), typeof(fx), fdtype}(copy(x),fx)
34+
end
35+
36+
"""
37+
FiniteDiff.finite_difference_jvp(
38+
f,
39+
x :: AbstractArray{<:Number},
40+
v :: AbstractArray{<:Number},
41+
fdtype :: Type{T1}=Val{:central},
42+
relstep=default_relstep(fdtype, eltype(x)),
43+
absstep=relstep)
44+
45+
Cache-less.
46+
"""
47+
function finite_difference_jvp(f, x, v,
48+
fdtype = Val(:forward),
49+
f_in = nothing;
50+
relstep=default_relstep(fdtype, eltype(x)),
51+
absstep=relstep,
52+
dir=true)
53+
54+
if f_in isa Nothing
55+
fx = f(x)
56+
else
57+
fx = f_in
58+
end
59+
cache = JVPCache(x, fx, fdtype)
60+
finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir)
61+
end
62+
163
"""
264
FiniteDiff.finite_difference_jvp(
365
f,
466
x,
567
v,
6-
fdtype = Val(:forward),
7-
f_in=nothing;
8-
relstep=default_relstep(fdtype, eltype(x))
9-
absstep=relstep)
68+
cache::JVPCache;
69+
relstep=default_relstep(fdtype, eltype(x)),
70+
absstep=relstep,
71+
72+
Cached.
1073
"""
1174
function finite_difference_jvp(
1275
f,
1376
x,
1477
v,
15-
fdtype = Val(:forward),
16-
f_in = nothing;
17-
relstep=default_relstep(eltype(x), eltype(x)),
78+
cache::JVPCache{X1, FX1, fdtype},
79+
f_in=nothing;
80+
relstep=default_relstep(fdtype, eltype(x)),
1881
absstep=relstep,
19-
dir=true)
82+
dir=true) where {X1, FX1, fdtype}
83+
2084
if fdtype == Val(:complex)
2185
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
2286
end
23-
vecx = _vec(x)
24-
vecv = _vec(v)
87+
(; x1, fx1) = cache
2588

26-
tmp = sqrt(dot(vecx, vecv))
27-
epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir)
89+
tmp = sqrt(dot(_vec(x), _vec(v)))
90+
epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir)
2891
if fdtype == Val(:forward)
2992
fx = f_in isa Nothing ? f(x) : f_in
30-
_x = @. x + epsilon * v
31-
fx1 = f(_x)
32-
return @. (fx1-fx)/epsilon
93+
@. x1 = x + epsilon * v
94+
fx1 = f(x1)
95+
@. fx1 = (fx1-fx)/epsilon
3396
elseif fdtype == Val(:central)
34-
_x = @. x + epsilon * v
35-
fx1 = f(_x)
36-
_x = @. x - epsilon * v
37-
fx = f(_x)
38-
return @. (fx1-fx)/(2epsilon)
97+
@. x1 = x + epsilon * v
98+
fx1 = f(x1)
99+
@. x1 = x - epsilon * v
100+
fx = f(x1)
101+
@. fx1 = (fx1-fx)/(2epsilon)
39102
else
40103
fdtype_error(eltype(x))
41104
end
105+
fx1
42106
end
43107

44108
"""
45-
FiniteDiff.finite_difference_jvp!(
109+
finite_difference_jvp!(
46110
jvp::AbstractArray{<:Number},
47111
f,
48112
x::AbstractArray{<:Number},
49-
v,
50-
fdtype = Val(:forward),
51-
f_in=nothing,
52-
fx1 = nothing;
53-
relstep=default_relstep(fdtype, eltype(x))
113+
v::AbstractArray{<:Number},
114+
fdtype :: Type{T1}=Val{:forward},
115+
returntype :: Type{T2}=eltype(x),
116+
f_in :: Union{T2,Nothing}=nothing;
117+
relstep=default_relstep(fdtype, eltype(x)),
54118
absstep=relstep)
119+
120+
Cache-less.
121+
"""
122+
function finite_difference_jvp!(jvp,
123+
f,
124+
x,
125+
v,
126+
fdtype = Val(:forward),
127+
f_in = nothing;
128+
relstep=default_relstep(fdtype, eltype(x)),
129+
absstep=relstep)
130+
if !isnothing(f_in)
131+
cache = JVPCache(x, f_in, fdtype)
132+
elseif fdtype == Val(:forward)
133+
fx = zero(x)
134+
f(fx,x)
135+
cache = JVPCache(x, fx, fdtype)
136+
else
137+
cache = JVPCache(x, fdtype)
138+
end
139+
finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep)
140+
end
141+
142+
"""
143+
FiniteDiff.finite_difference_jvp!(
144+
jvp::AbstractArray{<:Number},
145+
f,
146+
x::AbstractArray{<:Number},
147+
v::AbstractArray{<:Number},
148+
cache::JVPCache;
149+
relstep=default_relstep(fdtype, eltype(x)),
150+
absstep=relstep,)
151+
152+
Cached.
55153
"""
56154
function finite_difference_jvp!(
57155
jvp,
58156
f,
59157
x,
60158
v,
61-
fdtype = Val(:forward),
62-
f_in = nothing,
63-
fx1 = nothing;
64-
relstep = default_relstep(eltype(x), eltype(x)),
159+
cache::JVPCache{X1, FX1, fdtype},
160+
f_in = nothing;
161+
relstep = default_relstep(fdtype, eltype(x)),
65162
absstep = relstep,
66-
dir = true)
163+
dir = true) where {X1, FX1, fdtype}
164+
67165
if fdtype == Val(:complex)
68166
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
69167
end
70-
vecx = _vec(x)
71-
vecv = _vec(v)
72168

73-
tmp = sqrt(dot(vecx, vecv))
74-
epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir)
169+
(;x1, fx1) = cache
170+
tmp = sqrt(dot(_vec(x), _vec(v)))
171+
epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir)
75172
if fdtype == Val(:forward)
76173
if f_in isa Nothing
77-
fx1 = copy(jvp)
78174
f(fx1, x)
79175
else
80176
fx1 = f_in
81177
end
82-
@. x = x + epsilon * v
83-
f(jvp, x)
178+
@. x1 = x + epsilon * v
179+
f(jvp, x1)
84180
@. jvp = (jvp-fx)/epsilon
85181
elseif fdtype == Val(:central)
86-
@. x = x - epsilon * v
87-
if fx1 isa Nothing
88-
fx1 = copy(jvp)
89-
end
90-
f(fx1, x)
91-
@. x = x + epsilon * v
92-
f(jvp, x)
182+
@. x1 = x - epsilon * v
183+
f(fx1, x1)
184+
@. x1 = x + epsilon * v
185+
f(jvp, x1)
93186
@. jvp = (jvp-fx1)/(2epsilon)
94187
else
95188
fdtype_error(eltype(x))
96189
end
97190
nothing
98191
end
192+
193+
function resize!(cache::JVPCache, i::Int)
194+
resize!(cache.x1, i)
195+
cache.fx1 !== nothing && resize!(cache.fx1, i)
196+
nothing
197+
end

0 commit comments

Comments
 (0)