Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ function that is being called.

`DiffCache` is a type for doubly-preallocated vectors which are
compatible with non-allocating forward-mode automatic differentiation by
ForwardDiff.jl. Since ForwardDiff uses chunked duals in its forward pass, two
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
Since ForwardDiff.jl uses chunked duals in its forward pass, two
vector sizes are required in order for the arrays to be properly defined.
`DiffCache` creates a dispatching type to solve this, so that by passing a
qualifier it can automatically switch between the required cache. This method
is fully type-stable and non-dynamic, made for when the highest performance is
needed.

The `DiffCache` also supports sparsity detection via
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
However, the implementation may allocate memory in this case since we assume that
sparsity detection happens only once (or maybe a few times). Allocating memory
allows to save memory in the long run since no additional cache needs to be stored
forever.

### Using DiffCache

```julia
Expand Down
49 changes: 13 additions & 36 deletions ext/PreallocationToolsSparseConnectivityTracerExt.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,23 @@
module PreallocationToolsSparseConnectivityTracerExt

using PreallocationTools
isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) :
(import ..SparseConnectivityTracer)
using PreallocationTools: PreallocationTools, DiffCache, get_tmp
using SparseConnectivityTracer: AbstractTracer, Dual

function PreallocationTools.get_tmp(
dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual}
if isbitstype(T)
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
PreallocationTools.enlargediffcache!(dc, nelem)
end
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
else
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
end
function PreallocationTools.get_tmp(dc::DiffCache, u::T) where {T <: Union{AbstractTracer, Dual}}
return get_tmp(dc, typeof(u))
end

function PreallocationTools.get_tmp(
dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual}
if isbitstype(T)
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
PreallocationTools.enlargediffcache!(dc, nelem)
end
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
else
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
end
function PreallocationTools.get_tmp(dc::DiffCache, u::AbstractArray{<:T}) where {T <: Union{AbstractTracer, Dual}}
return get_tmp(dc, eltype(u))
end

function PreallocationTools.get_tmp(
dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual}
if isbitstype(T)
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
PreallocationTools.enlargediffcache!(dc, nelem)
end
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
else
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
end
function PreallocationTools.get_tmp(dc::DiffCache, ::Type{T}) where {T <: Union{AbstractTracer, Dual}}
# We allocate memory here since we assume that sparsity connection happens only
# once (or maybe a few times). This simplifies the implementation and allows us
# to save memory in the long run since we do not need to store an additional
# cache for the sparsity detection that would be used only once but carried
# around forever.
return similar(dc.du, T)
end

end
8 changes: 6 additions & 2 deletions src/PreallocationTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ end

Builds a `DiffCache` object that stores both a version of the cache for `u`
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
forward-mode automatic differentiation. Supports nested AD via keyword `levels`
or specifying an array of chunk_sizes.
forward-mode automatic differentiation via
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
Supports nested AD via keyword `levels` or specifying an array of chunk sizes.

The `DiffCache` also supports sparsity detection via
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
"""
function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
levels::Int = 1)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "DiffCache Resizing" include("core_resizing.jl")
@safetestset "DiffCache Nested Duals" include("core_nesteddual.jl")
@safetestset "DiffCache Sparsity Support" include("sparsity_support.jl")
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
@safetestset "LazyBufferCache" include("lbc.jl")
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
end
Expand Down
63 changes: 63 additions & 0 deletions test/sparse_connectivity_tracer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
module TestSparseConnectivityTracer

using PreallocationTools, SparseConnectivityTracer, ForwardDiff, SparseArrays, Test

function f1(u, cache)
c = get_tmp(cache, u)
# This will throw if a fallback definition is used
# such that `eltype(c) == Any`
T = eltype(c)
@. c = u^2 + one(T)
return sum(c)
end

@testset "out of place" begin
u = rand(10)
cache = DiffCache(u)

@test_nowarn @inferred f1(u, cache)
@test_nowarn ForwardDiff.gradient(u) do u
f1(u, cache)
end
@test_nowarn jacobian_sparsity(u, TracerSparsityDetector()) do u
f1(u, cache)
end
@test_nowarn hessian_sparsity(u, TracerSparsityDetector()) do u
f1(u, cache)
end
@test_nowarn jacobian_sparsity(u, TracerLocalSparsityDetector()) do u
f1(u, cache)
end
@test_nowarn hessian_sparsity(u, TracerLocalSparsityDetector()) do u
f1(u, cache)
end
end

function f1!(du, u, cache)
c = get_tmp(cache, u)
# This will throw if a fallback definition is used
# such that `eltype(c) == Any`
T = eltype(c)
@. c = u^2 + one(T)
du[1] = sum(c)
return nothing
end

@testset "in place" begin
u = rand(10)
cache = DiffCache(u)
du = similar(u, (1,))

@test_nowarn @inferred f1!(du, u, cache)
@test_nowarn ForwardDiff.jacobian(du, u) do du, u
f1!(du, u, cache)
end
@test_nowarn jacobian_sparsity(du, u, TracerSparsityDetector()) do du, u
f1!(du, u, cache)
end
@test_nowarn jacobian_sparsity(du, u, TracerLocalSparsityDetector()) do du, u
f1!(du, u, cache)
end
end

end
Loading