Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/FiniteDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ include("derivatives.jl")
include("gradients.jl")
include("jacobians.jl")
include("hessians.jl")
include("jvp.jl")

end # module
197 changes: 197 additions & 0 deletions src/jvp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
mutable struct JVPCache{X1, FX1, FDType}
x1 :: X1
fx1 :: FX1
end

"""
FiniteDiff.JVPCache(
x,
fdtype :: Type{T1} = Val{:forward})

Allocating Cache Constructor.
"""
function JVPCache(
x,
fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD}
fdtype isa Type && (fdtype = fdtype())
JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x))
end

"""
FiniteDiff.JVPCache(
x,
fx1,
fdtype :: Type{T1} = Val{:forward},

Non-Allocating Cache Constructor.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well it does make a copy of x

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes the non-allocating cache is expected to just use the arrays as given 😅

"""
function JVPCache(
x,
fx,
fdtype::Union{Val{FD},Type{FD}} = Val(:forward)) where {FD}
fdtype isa Type && (fdtype = fdtype())
JVPCache{typeof(x), typeof(fx), fdtype}(copy(x),fx)
end

"""
FiniteDiff.finite_difference_jvp(
f,
x :: AbstractArray{<:Number},
v :: AbstractArray{<:Number},
fdtype :: Type{T1}=Val{:central},
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep)

Cache-less.
"""
function finite_difference_jvp(f, x, v,
fdtype = Val(:forward),
f_in = nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,
dir=true)

if f_in isa Nothing
fx = f(x)
else
fx = f_in
end
cache = JVPCache(x, fx, fdtype)
finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir)
end

"""
FiniteDiff.finite_difference_jvp(
f,
x,
v,
cache::JVPCache;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,

Cached.
"""
function finite_difference_jvp(
f,
x,
v,
cache::JVPCache{X1, FX1, fdtype},
f_in=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,
dir=true) where {X1, FX1, fdtype}

if fdtype == Val(:complex)
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
end
(; x1, fx1) = cache

tmp = sqrt(abs(dot(_vec(x), _vec(v))))
epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir)
if fdtype == Val(:forward)
fx = f_in isa Nothing ? f(x) : f_in
@. x1 = x + epsilon * v
fx1 = f(x1)
@. fx1 = (fx1-fx)/epsilon
elseif fdtype == Val(:central)
@. x1 = x + epsilon * v
fx1 = f(x1)
@. x1 = x - epsilon * v
fx = f(x1)
@. fx1 = (fx1-fx)/(2epsilon)
else
fdtype_error(eltype(x))
end
fx1
end

"""
finite_difference_jvp!(
jvp::AbstractArray{<:Number},
f,
x::AbstractArray{<:Number},
v::AbstractArray{<:Number},
fdtype :: Type{T1}=Val{:forward},
returntype :: Type{T2}=eltype(x),
f_in :: Union{T2,Nothing}=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep)

Cache-less.
"""
function finite_difference_jvp!(jvp,
f,
x,
v,
fdtype = Val(:forward),
f_in = nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep)
if !isnothing(f_in)
cache = JVPCache(x, f_in, fdtype)
elseif fdtype == Val(:forward)
fx = zero(x)
f(fx,x)
cache = JVPCache(x, fx, fdtype)
else
cache = JVPCache(x, fdtype)
end
finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep)
end

"""
FiniteDiff.finite_difference_jvp!(
jvp::AbstractArray{<:Number},
f,
x::AbstractArray{<:Number},
v::AbstractArray{<:Number},
cache::JVPCache;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,)

Cached.
"""
function finite_difference_jvp!(
jvp,
f,
x,
v,
cache::JVPCache{X1, FX1, fdtype},
f_in = nothing;
relstep = default_relstep(fdtype, eltype(x)),
absstep = relstep,
dir = true) where {X1, FX1, fdtype}

if fdtype == Val(:complex)
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
end

(;x1, fx1) = cache
tmp = sqrt(abs(dot(_vec(x), _vec(v))))
epsilon = compute_epsilon(fdtype, tmp, relstep, absstep, dir)
if fdtype == Val(:forward)
if f_in isa Nothing
f(fx1, x)
else
fx1 = f_in
end
@. x1 = x + epsilon * v
f(jvp, x1)
@. jvp = (jvp-fx1)/epsilon
elseif fdtype == Val(:central)
@. x1 = x - epsilon * v
f(fx1, x1)
@. x1 = x + epsilon * v
f(jvp, x1)
@. jvp = (jvp-fx1)/(2epsilon)
else
fdtype_error(eltype(x))
end
nothing
end

function resize!(cache::JVPCache, i::Int)
resize!(cache.x1, i)
cache.fx1 !== nothing && resize!(cache.fx1, i)
nothing
end
50 changes: 40 additions & 10 deletions test/finitedifftests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,38 +382,68 @@ df = zero(x)
df_ref = diag(J_ref)
epsilon = zero(x)
forward_cache = FiniteDiff.JacobianCache(x, Val{:forward}, eltype(x))
forward_jvp_cache = FiniteDiff.JVPCache(x, Val{:forward})
@test forward_cache.colorvec == 1:length(x)
central_cache = FiniteDiff.JacobianCache(x, Val{:central}, eltype(x))
central_jvp_cache = FiniteDiff.JVPCache(x, Val{:central})
complex_cache = FiniteDiff.JacobianCache(x, Val{:complex}, eltype(x))
f_in = copy(y)
vdir = rand(2)
jvp_ref = J_ref*vdir

@time @testset "Out-of-Place Jacobian StridedArray real-valued tests" begin
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-4
@test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-4
@test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-4
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-4
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache), J_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache, dir=-1), J_ref) < 1e-6
@test_throws Any err_func(FiniteDiff.finite_difference_jacobian(oopff, x, forward_cache), J_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, forward_cache, f_in), J_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, central_cache), J_ref) < 1e-8
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, Val{:central}), J_ref) < 1e-8
@test err_func(FiniteDiff.finite_difference_jacobian(oopf, x, complex_cache), J_ref) < 1e-14
end

@time @testset "Out-of-Place JVP StridedArray real-valued tests" begin
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6
@test_throws Any err_func(FiniteDiff.finite_difference_jvp(oopff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8
@test err_func(FiniteDiff.finite_difference_jvp(oopf, x, vdir, Val{:central}), jvp_ref) < 1e-8
end

function test_iipJac(J_ref, args...; kwargs...)
_J = zero(J_ref)
FiniteDiff.finite_difference_jacobian!(_J, args...; kwargs...)
_J
end
@time @testset "inPlace Jacobian StridedArray real-valued tests" begin
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-4
@test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-4
@test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-4
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-4
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-4
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache), J_ref) < 1e-6
@test err_func(test_iipJac(J_ref, iipff, x, forward_cache, dir=-1), J_ref) < 1e-6
@test_throws Any err_func(test_iipJac(J_ref, iipff, x, forward_cache), J_ref) < 1e-6
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, relstep=sqrt(eps())), J_ref) < 1e-6
@test err_func(test_iipJac(J_ref, iipf, x, forward_cache, f_in), J_ref) < 1e-6
@test err_func(test_iipJac(J_ref, iipf, x, central_cache), J_ref) < 1e-8
@test err_func(test_iipJac(J_ref, iipf, x, Val{:central}), J_ref) < 1e-8
@test err_func(test_iipJac(J_ref, iipf, x, complex_cache), J_ref) < 1e-14
end

function test_iipJVP(jvp_ref, args...; kwargs...)
_jvp = zero(jvp_ref)
FiniteDiff.finite_difference_jvp!(_jvp, args...; kwargs...)
_jvp
end

@time @testset "inPlace JVP StridedArray real-valued tests" begin
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
@test err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache, dir=-1), jvp_ref) < 1e-6
@test_throws Any err_func(test_iipJVP(jvp_ref, iipff, x, vdir, forward_jvp_cache), jvp_ref) < 1e-6
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, relstep=sqrt(eps())), jvp_ref) < 1e-6
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, forward_jvp_cache, f_in), jvp_ref) < 1e-6
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, central_jvp_cache), jvp_ref) < 1e-8
@test err_func(test_iipJVP(jvp_ref, iipf, x, vdir, Val{:central}), jvp_ref) < 1e-8
end

function iipf(fvec, x)
fvec[1] = (im * x[1] + 3) * (x[2]^3 - 7) + 18
fvec[2] = sin(x[2] * exp(x[1]) - 1)
Expand Down
Loading