Skip to content

Commit 36fdc6f

Browse files
Merge pull request #122 from ranocha/hr/SparseConnectivityTracer
support SparseConnectivityTracer.jl
2 parents 94ccad1 + 8406b13 commit 36fdc6f

File tree

5 files changed

+92
-39
lines changed

5 files changed

+92
-39
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,21 @@ function that is being called.
2020

2121
`DiffCache` is a type for doubly-preallocated vectors which are
2222
compatible with non-allocating forward-mode automatic differentiation by
23-
ForwardDiff.jl. Since ForwardDiff uses chunked duals in its forward pass, two
23+
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
24+
Since ForwardDiff.jl uses chunked duals in its forward pass, two
2425
vector sizes are required in order for the arrays to be properly defined.
2526
`DiffCache` creates a dispatching type to solve this, so that by passing a
2627
qualifier it can automatically switch between the required cache. This method
2728
is fully type-stable and non-dynamic, made for when the highest performance is
2829
needed.
2930

31+
The `DiffCache` also supports sparsity detection via
32+
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
33+
However, the implementation may allocate memory in this case since we assume that
34+
sparsity detection happens only once (or maybe a few times). Allocating memory
35+
allows to save memory in the long run since no additional cache needs to be stored
36+
forever.
37+
3038
### Using DiffCache
3139

3240
```julia
Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,23 @@
11
module PreallocationToolsSparseConnectivityTracerExt
22

3-
using PreallocationTools
4-
isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) :
5-
(import ..SparseConnectivityTracer)
3+
using PreallocationTools: PreallocationTools, DiffCache, get_tmp
4+
using SparseConnectivityTracer: AbstractTracer, Dual
65

7-
function PreallocationTools.get_tmp(
8-
dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual}
9-
if isbitstype(T)
10-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
11-
if nelem > length(dc.dual_du)
12-
PreallocationTools.enlargediffcache!(dc, nelem)
13-
end
14-
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
15-
else
16-
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
17-
end
6+
function PreallocationTools.get_tmp(dc::DiffCache, u::T) where {T <: Union{AbstractTracer, Dual}}
7+
return get_tmp(dc, typeof(u))
188
end
199

20-
function PreallocationTools.get_tmp(
21-
dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual}
22-
if isbitstype(T)
23-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
24-
if nelem > length(dc.dual_du)
25-
PreallocationTools.enlargediffcache!(dc, nelem)
26-
end
27-
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
28-
else
29-
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
30-
end
10+
function PreallocationTools.get_tmp(dc::DiffCache, u::AbstractArray{<:T}) where {T <: Union{AbstractTracer, Dual}}
11+
return get_tmp(dc, eltype(u))
3112
end
3213

33-
function PreallocationTools.get_tmp(
34-
dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual}
35-
if isbitstype(T)
36-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
37-
if nelem > length(dc.dual_du)
38-
PreallocationTools.enlargediffcache!(dc, nelem)
39-
end
40-
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
41-
else
42-
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
43-
end
14+
function PreallocationTools.get_tmp(dc::DiffCache, ::Type{T}) where {T <: Union{AbstractTracer, Dual}}
15+
# We allocate memory here since we assume that sparsity connection happens only
16+
# once (or maybe a few times). This simplifies the implementation and allows us
17+
# to save memory in the long run since we do not need to store an additional
18+
# cache for the sparsity detection that would be used only once but carried
19+
# around forever.
20+
return similar(dc.du, T)
4421
end
4522

4623
end

src/PreallocationTools.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,12 @@ end
107107
108108
Builds a `DiffCache` object that stores both a version of the cache for `u`
109109
and for the `Dual` version of `u`, allowing use of pre-cached vectors with
110-
forward-mode automatic differentiation. Supports nested AD via keyword `levels`
111-
or specifying an array of chunk_sizes.
110+
forward-mode automatic differentiation via
111+
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
112+
Supports nested AD via keyword `levels` or specifying an array of chunk sizes.
113+
114+
The `DiffCache` also supports sparsity detection via
115+
[SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl/).
112116
"""
113117
function DiffCache(u::AbstractArray, N::Int = ForwardDiff.pickchunksize(length(u));
114118
levels::Int = 1)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core"
1616
@safetestset "DiffCache Resizing" include("core_resizing.jl")
1717
@safetestset "DiffCache Nested Duals" include("core_nesteddual.jl")
1818
@safetestset "DiffCache Sparsity Support" include("sparsity_support.jl")
19+
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
1920
@safetestset "LazyBufferCache" include("lbc.jl")
2021
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
2122
end

test/sparse_connectivity_tracer.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
module TestSparseConnectivityTracer
2+
3+
using PreallocationTools, SparseConnectivityTracer, ForwardDiff, SparseArrays, Test
4+
5+
function f1(u, cache)
6+
c = get_tmp(cache, u)
7+
# This will throw if a fallback definition is used
8+
# such that `eltype(c) == Any`
9+
T = eltype(c)
10+
@. c = u^2 + one(T)
11+
return sum(c)
12+
end
13+
14+
@testset "out of place" begin
15+
u = rand(10)
16+
cache = DiffCache(u)
17+
18+
@test_nowarn @inferred f1(u, cache)
19+
@test_nowarn ForwardDiff.gradient(u) do u
20+
f1(u, cache)
21+
end
22+
@test_nowarn jacobian_sparsity(u, TracerSparsityDetector()) do u
23+
f1(u, cache)
24+
end
25+
@test_nowarn hessian_sparsity(u, TracerSparsityDetector()) do u
26+
f1(u, cache)
27+
end
28+
@test_nowarn jacobian_sparsity(u, TracerLocalSparsityDetector()) do u
29+
f1(u, cache)
30+
end
31+
@test_nowarn hessian_sparsity(u, TracerLocalSparsityDetector()) do u
32+
f1(u, cache)
33+
end
34+
end
35+
36+
function f1!(du, u, cache)
37+
c = get_tmp(cache, u)
38+
# This will throw if a fallback definition is used
39+
# such that `eltype(c) == Any`
40+
T = eltype(c)
41+
@. c = u^2 + one(T)
42+
du[1] = sum(c)
43+
return nothing
44+
end
45+
46+
@testset "in place" begin
47+
u = rand(10)
48+
cache = DiffCache(u)
49+
du = similar(u, (1,))
50+
51+
@test_nowarn @inferred f1!(du, u, cache)
52+
@test_nowarn ForwardDiff.jacobian(du, u) do du, u
53+
f1!(du, u, cache)
54+
end
55+
@test_nowarn jacobian_sparsity(du, u, TracerSparsityDetector()) do du, u
56+
f1!(du, u, cache)
57+
end
58+
@test_nowarn jacobian_sparsity(du, u, TracerLocalSparsityDetector()) do du, u
59+
f1!(du, u, cache)
60+
end
61+
end
62+
63+
end

0 commit comments

Comments
 (0)