|
| 1 | +""" |
| 2 | + FiniteDiff.finite_difference_jvp( |
| 3 | + f, |
| 4 | + x, |
| 5 | + v, |
| 6 | + fdtype = Val(:forward), |
| 7 | + f_in=nothing; |
| 8 | + relstep=default_relstep(fdtype, eltype(x)) |
| 9 | + absstep=relstep) |
| 10 | +""" |
| 11 | +function finite_difference_jvp( |
| 12 | + f, |
| 13 | + x, |
| 14 | + v |
| 15 | + fdtype = Val(:forward), |
| 16 | + f_in = nothing; |
| 17 | + relstep=default_relstep(eltype(x), eltype(x)), |
| 18 | + absstep=relstep, |
| 19 | + dir=true) |
| 20 | + if fdtype == Val(:complex) |
| 21 | + ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") |
| 22 | + end |
| 23 | + vecx = _vec(x) |
| 24 | + vecv = _vec(v) |
| 25 | + |
| 26 | + tmp = sqrt(dot(vecx, vecv)) |
| 27 | + epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir) |
| 28 | + if fdtype == Val(:forward) |
| 29 | + fx = f_in isa Nothing ? f(x) : f_in |
| 30 | + _x = @. x + epsilon * v |
| 31 | + fx1 = f(_x) |
| 32 | + return @. (fx1-fx)/epsilon |
| 33 | + elseif fdtype == Val(:central) |
| 34 | + _x = @. x + epsilon * v |
| 35 | + fx1 = f(_x) |
| 36 | + _x = @. x - epsilon * v |
| 37 | + fx = f(_x) |
| 38 | + return @. (fx1-fx)/(2epsilon) |
| 39 | + else |
| 40 | + fdtype_error(eltype(x)) |
| 41 | + end |
| 42 | +end |
| 43 | + |
| 44 | +""" |
| 45 | + FiniteDiff.finite_difference_jvp!( |
| 46 | + jvp::AbstractArray{<:Number}, |
| 47 | + f, |
| 48 | + x::AbstractArray{<:Number}, |
| 49 | + v, |
| 50 | + fdtype = Val(:forward), |
| 51 | + f_in=nothing, |
| 52 | + fx1 = nothing; |
| 53 | + relstep=default_relstep(fdtype, eltype(x)) |
| 54 | + absstep=relstep) |
| 55 | +""" |
| 56 | +function finite_difference_jvp!( |
| 57 | + jvp, |
| 58 | + f, |
| 59 | + x, |
| 60 | + v, |
| 61 | + fdtype = Val(:forward), |
| 62 | + f_in = nothing, |
| 63 | + fx1 = nothing; |
| 64 | + relstep = default_relstep(eltype(x), eltype(x)), |
| 65 | + absstep = relstep, |
| 66 | + dir = true) |
| 67 | + if fdtype == Val(:complex) |
| 68 | + ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") |
| 69 | + end |
| 70 | + vecx = _vec(x) |
| 71 | + vecv = _vec(v) |
| 72 | + |
| 73 | + tmp = sqrt(dot(vecx, vecv)) |
| 74 | + epsilon = compute_epsilon(fdtype, sqrt(tmp), relstep, absstep, dir) |
| 75 | + if fdtype == Val(:forward) |
| 76 | + if f_in isa Nothing |
| 77 | + fx1 = copy(jvp) |
| 78 | + f(fx1, x) |
| 79 | + else |
| 80 | + fx1 = f_in |
| 81 | + end |
| 82 | + @. x = x + epsilon * v |
| 83 | + f(jvp, x) |
| 84 | + @. jvp = (jvp-fx)/epsilon |
| 85 | + elseif fdtype == Val(:central) |
| 86 | + @. x = x - epsilon * v |
| 87 | + if fx1 isa Nothing |
| 88 | + fx1 = copy(jvp) |
| 89 | + end |
| 90 | + f(fx1, x) |
| 91 | + @. x = x + epsilon * v |
| 92 | + f(jvp, x) |
| 93 | + @. jvp = (jvp-fx1)/(2epsilon) |
| 94 | + else |
| 95 | + fdtype_error(eltype(x)) |
| 96 | + end |
| 97 | + nothing |
| 98 | +end |
0 commit comments