diff --git a/README.md b/README.md index bd19d24..6ecdc90 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,21 @@ function that is being called. `DiffCache` is a type for doubly-preallocated vectors which are compatible with non-allocating forward-mode automatic differentiation by -ForwardDiff.jl. Since ForwardDiff uses chunked duals in its forward pass, two +[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). +Since ForwardDiff.jl uses chunked duals in its forward pass, two vector sizes are required in order for the arrays to be properly defined. `DiffCache` creates a dispatching type to solve this, so that by passing a qualifier it can automatically switch between the required cache. This method is fully type-stable and non-dynamic, made for when the highest performance is needed. +The `DiffCache` also supports sparsity detection via +[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/). +However, the implementation may allocate memory in this case since we assume that +sparsity detection happens only once (or maybe a few times). Allocating memory +allows to save memory in the long run since no additional cache needs to be stored +forever. + ### Using DiffCache ```julia diff --git a/ext/PreallocationToolsSparseConnectivityTracerExt.jl b/ext/PreallocationToolsSparseConnectivityTracerExt.jl index 1701eab..a0242c2 100644 --- a/ext/PreallocationToolsSparseConnectivityTracerExt.jl +++ b/ext/PreallocationToolsSparseConnectivityTracerExt.jl @@ -1,46 +1,23 @@ module PreallocationToolsSparseConnectivityTracerExt -using PreallocationTools -isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) : -(import ..SparseConnectivityTracer) +using PreallocationTools: PreallocationTools, DiffCache, get_tmp +using SparseConnectivityTracer: AbstractTracer, Dual -function PreallocationTools.get_tmp( - dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual} - if isbitstype(T) - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) - if nelem > length(dc.dual_du) - PreallocationTools.enlargediffcache!(dc, nelem) - end - PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) - else - PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) - end +function PreallocationTools.get_tmp(dc::DiffCache, u::T) where {T <: Union{AbstractTracer, Dual}} + return get_tmp(dc, typeof(u)) end -function PreallocationTools.get_tmp( - dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual} - if isbitstype(T) - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) - if nelem > length(dc.dual_du) - PreallocationTools.enlargediffcache!(dc, nelem) - end - PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) - else - PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) - end +function PreallocationTools.get_tmp(dc::DiffCache, u::AbstractArray{<:T}) where {T <: Union{AbstractTracer, Dual}} + return get_tmp(dc, eltype(u)) end -function PreallocationTools.get_tmp( - dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual} - if isbitstype(T) - nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) - if nelem > length(dc.dual_du) - PreallocationTools.enlargediffcache!(dc, nelem) - end - PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) - else - PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) - end +function PreallocationTools.get_tmp(dc::DiffCache, ::Type{T}) where {T <: Union{AbstractTracer, Dual}} + # We allocate memory here since we assume that sparsity connection happens only + # once (or maybe a few times). This simplifies the implementation and allows us + # to save memory in the long run since we do not need to store an additional + # cache for the sparsity detection that would be used only once but carried + # around forever. + return similar(dc.du, T) end end diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index fd78e41..bcfe5a1 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -107,8 +107,12 @@ end Builds a `DiffCache` object that stores both a version of the cache for `u` and for the `Dual` version of `u`, allowing use of pre-cached vectors with -forward-mode automatic differentiation. Supports nested AD via keyword `levels` -or specifying an array of chunk_sizes. +forward-mode automatic differentiation via +[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). +Supports nested AD via keyword `levels` or specifying an array of chunk sizes. + +The `DiffCache` also supports sparsity detection via +[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/). """ function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u)); levels::Int = 1) diff --git a/test/runtests.jl b/test/runtests.jl index 369a492..0ae51df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "DiffCache Resizing" include("core_resizing.jl") @safetestset "DiffCache Nested Duals" include("core_nesteddual.jl") @safetestset "DiffCache Sparsity Support" include("sparsity_support.jl") + @safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl") @safetestset "LazyBufferCache" include("lbc.jl") @safetestset "GeneralLazyBufferCache" include("general_lbc.jl") end diff --git a/test/sparse_connectivity_tracer.jl b/test/sparse_connectivity_tracer.jl new file mode 100644 index 0000000..9bbf5d3 --- /dev/null +++ b/test/sparse_connectivity_tracer.jl @@ -0,0 +1,63 @@ +module TestSparseConnectivityTracer + +using PreallocationTools, SparseConnectivityTracer, ForwardDiff, SparseArrays, Test + +function f1(u, cache) + c = get_tmp(cache, u) + # This will throw if a fallback definition is used + # such that `eltype(c) == Any` + T = eltype(c) + @. c = u^2 + one(T) + return sum(c) +end + +@testset "out of place" begin + u = rand(10) + cache = DiffCache(u) + + @test_nowarn @inferred f1(u, cache) + @test_nowarn ForwardDiff.gradient(u) do u + f1(u, cache) + end + @test_nowarn jacobian_sparsity(u, TracerSparsityDetector()) do u + f1(u, cache) + end + @test_nowarn hessian_sparsity(u, TracerSparsityDetector()) do u + f1(u, cache) + end + @test_nowarn jacobian_sparsity(u, TracerLocalSparsityDetector()) do u + f1(u, cache) + end + @test_nowarn hessian_sparsity(u, TracerLocalSparsityDetector()) do u + f1(u, cache) + end +end + +function f1!(du, u, cache) + c = get_tmp(cache, u) + # This will throw if a fallback definition is used + # such that `eltype(c) == Any` + T = eltype(c) + @. c = u^2 + one(T) + du[1] = sum(c) + return nothing +end + +@testset "in place" begin + u = rand(10) + cache = DiffCache(u) + du = similar(u, (1,)) + + @test_nowarn @inferred f1!(du, u, cache) + @test_nowarn ForwardDiff.jacobian(du, u) do du, u + f1!(du, u, cache) + end + @test_nowarn jacobian_sparsity(du, u, TracerSparsityDetector()) do du, u + f1!(du, u, cache) + end + @test_nowarn jacobian_sparsity(du, u, TracerLocalSparsityDetector()) do du, u + f1!(du, u, cache) + end +end + +end