Skip to content

Commit dcabf27

Browse files
authored
add finite_difference_jvp
Add the pushforward operation with implementation taken from jacobian but simplified.
1 parent b574440 commit dcabf27

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

src/jvp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
FiniteDiff.finite_difference_jvp(
3+
f,
4+
x,
5+
v,
6+
fdtype = Val(:forward),
7+
f_in=nothing;
8+
relstep=default_relstep(fdtype, eltype(x))
9+
absstep=relstep)
10+
"""
11+
function finite_difference_jvp(
12+
f,
13+
x,
14+
v
15+
fdtype = Val(:forward),
16+
f_in = nothing;
17+
relstep=default_relstep(eltype(x), eltype(x)),
18+
absstep=relstep,
19+
dir=true)
20+
if fdtype == Val(:complex)
21+
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
22+
end
23+
vecx = _vec(x)
24+
vecv = _vec(v)
25+
26+
tmp = sqrt(dot(vecx, vecv))
27+
epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir)
28+
if fdtype == Val(:forward)
29+
fx = f_in isa Nothing ? f(x) : f_in
30+
_x = @. x + epsilon * v
31+
fx1 = f(_x)
32+
return @. (fx1-fx)/epsilon
33+
elseif fdtype == Val(:central)
34+
_x = @. x + epsilon * v
35+
fx1 = f(_x)
36+
_x = @. x - epsilon * v
37+
fx = f(_x)
38+
return @. (fx1-fx)/(2epsilon)
39+
else
40+
fdtype_error(eltype(x))
41+
end
42+
end
43+
44+
"""
45+
FiniteDiff.finite_difference_jvp!(
46+
jvp::AbstractArray{<:Number},
47+
f,
48+
x::AbstractArray{<:Number},
49+
v,
50+
fdtype = Val(:forward),
51+
f_in=nothing,
52+
fx1 = nothing;
53+
relstep=default_relstep(fdtype, eltype(x))
54+
absstep=relstep)
55+
"""
56+
function finite_difference_jvp!(
57+
jvp,
58+
f,
59+
x,
60+
v,
61+
fdtype = Val(:forward),
62+
f_in = nothing,
63+
fx1 = nothing;
64+
relstep = default_relstep(eltype(x), eltype(x)),
65+
absstep = relstep,
66+
dir = true)
67+
if fdtype == Val(:complex)
68+
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
69+
end
70+
vecx = _vec(x)
71+
vecv = _vec(v)
72+
73+
tmp = sqrt(dot(vecx, vecv))
74+
epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir)
75+
if fdtype == Val(:forward)
76+
if f_in isa Nothing
77+
fx1 = copy(jvp)
78+
f(fx1, x)
79+
else
80+
fx1 = f_in
81+
end
82+
@. x = x + epsilon * v
83+
f(jvp, x)
84+
@. jvp = (jvp-fx)/epsilon
85+
elseif fdtype == Val(:central)
86+
@. x = x - epsilon * v
87+
if fx1 isa Nothing
88+
fx1 = copy(jvp)
89+
end
90+
f(fx1, x)
91+
@. x = x + epsilon * v
92+
f(jvp, x)
93+
@. jvp = (jvp-fx1)/(2epsilon)
94+
else
95+
fdtype_error(eltype(x))
96+
end
97+
nothing
98+
end

0 commit comments

Comments
 (0)