Skip to content

Commit 1cf497c

Browse files
authored
Add DenseSparsityDetector (#297)
* Add DenseSparsityDetector * Doctest * Add local warning * More warnings * Tests with more shapes * Fix matrices * Doc * Coverage
1 parent 23526d8 commit 1cf497c

File tree

6 files changed

+257
-2
lines changed

6 files changed

+257
-2
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.5.2"
4+
version = "0.5.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ DifferentiationInterface.inner
111111
DifferentiateWith
112112
```
113113

114+
### Sparsity detection
115+
116+
```@docs
117+
DenseSparsityDetector
118+
```
119+
114120
## Internals
115121

116122
The following is not part of the public API.

DifferentiationInterface/docs/src/operators.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ For this to work, three ingredients are needed (read [this survey](https://epubs
162162
1. An underlying (dense) backend
163163
2. A sparsity pattern detector like:
164164
- [`TracerSparsityDetector`](@extref SparseConnectivityTracer.TracerSparsityDetector) from [SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl)
165-
- [`SymbolicsSparsityDetector`](https://symbolics.juliasymbolics.org/dev/manual/sparsity_detection/) from [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl)
165+
- [`SymbolicsSparsityDetector`](@extref Symbolics.SymbolicsSparsityDetector) from [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl)
166+
- [`DenseSparsityDetector`](@ref) from DifferentiationInterface.jl (beware that this detector only gives a locally valid pattern)
166167
3. A coloring algorithm like:
167168
- [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl)
168169

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ include("sparse/jacobian.jl")
7272
include("sparse/hessian.jl")
7373

7474
include("misc/differentiate_with.jl")
75+
include("misc/sparsity_detector.jl")
7576

7677
function __init__()
7778
@require_extensions
@@ -108,6 +109,7 @@ export prepare_second_derivative, prepare_hessian
108109
export check_available, check_twoarg, check_hessian
109110

110111
export DifferentiateWith
112+
export DenseSparsityDetector
111113

112114
## Re-exported from ADTypes
113115

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
"""
    DenseSparsityDetector

Sparsity pattern detector satisfying the [detection API](https://sciml.github.io/ADTypes.jl/stable/#Sparse-AD) of [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

The nonzeros in a Jacobian or Hessian are detected by computing the relevant matrix with _dense_ AD, and thresholding the entries with a given tolerance (which can be numerically inaccurate).

!!! warning
    This detector can be very slow, and should only be used if its output can be exploited multiple times to compute many sparse matrices.

!!! danger
    In general, the sparsity pattern you obtain can depend on the provided input `x`. If you want to reuse the pattern, make sure that it is input-agnostic.

# Fields

- `backend::AbstractADType` is the dense AD backend used under the hood
- `atol::Float64` is the minimum magnitude of a matrix entry to be considered nonzero

# Constructor

    DenseSparsityDetector(backend; atol, method=:iterative)

The keyword argument `method::Symbol` can be either:

- `:iterative`: compute the matrix in a sequence of matrix-vector products (memory-efficient)
- `:direct`: compute the matrix all at once (memory-hungry but sometimes faster).

Note that the constructor is type-unstable because `method` ends up being a type parameter of the `DenseSparsityDetector` object (this is not part of the API and might change).

# Examples

```jldoctest detector
using ADTypes, DifferentiationInterface, SparseArrays
import ForwardDiff

detector = DenseSparsityDetector(AutoForwardDiff(); atol=1e-5, method=:direct)

ADTypes.jacobian_sparsity(diff, rand(5), detector)

# output

4×5 SparseMatrixCSC{Bool, Int64} with 8 stored entries:
 1  1  ⋅  ⋅  ⋅
 ⋅  1  1  ⋅  ⋅
 ⋅  ⋅  1  1  ⋅
 ⋅  ⋅  ⋅  1  1
```

Sometimes the sparsity pattern is input-dependent:

```jldoctest detector
ADTypes.jacobian_sparsity(x -> [prod(x)], rand(2), detector)

# output

1×2 SparseMatrixCSC{Bool, Int64} with 2 stored entries:
 1  1
```

```jldoctest detector
ADTypes.jacobian_sparsity(x -> [prod(x)], [0, 1], detector)

# output

1×2 SparseMatrixCSC{Bool, Int64} with 1 stored entry:
 1  ⋅
```
"""
struct DenseSparsityDetector{method,B} <: ADTypes.AbstractSparsityDetector
    backend::B
    atol::Float64
end
# Compact display: `DenseSparsityDetector{:direct}(AutoForwardDiff(); atol=1.0e-5)`.
function Base.show(io::IO, detector::DenseSparsityDetector{method}) where {method}
    return print(
        io,
        "DenseSparsityDetector{:$method}($(detector.backend); atol=$(detector.atol))",
    )
end
# Outer constructor: lifts `method` into the type domain (type-unstable by design,
# see the struct docstring) and validates it eagerly so that an invalid `method`
# fails at construction time rather than at detection time.
function DenseSparsityDetector(
    backend::AbstractADType; atol::Float64, method::Symbol=:iterative
)
    if !(method in (:iterative, :direct))
        throw(
            ArgumentError(
                "The keyword `method` must be either `:iterative` or `:direct`, got `:$method`.",
            ),
        )
    end
    return DenseSparsityDetector{method,typeof(backend)}(backend, atol)
end
## Direct

# Materialize the full dense Jacobian, then threshold it entrywise.
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
    jac = jacobian(f, detector.backend, x)
    return sparse(abs.(jac) .> detector.atol)
end
# Same as the out-of-place variant, for a mutating function `f!(y, x)`.
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct})
    jac = jacobian(f!, y, detector.backend, x)
    return sparse(abs.(jac) .> detector.atol)
end
# Materialize the full dense Hessian, then threshold it entrywise.
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
    hess = hessian(f, detector.backend, x)
    return sparse(abs.(hess) .> detector.atol)
end
## Iterative

# Build the pattern one Jacobian slice at a time, never materializing the dense
# matrix: columns via pushforwards (JVPs) when the backend is forward-fast,
# otherwise rows via pullbacks (VJPs).
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
    backend, atol = detector.backend, detector.atol
    y = f(x)
    n, m = length(x), length(y)
    rows, cols = Int[], Int[]
    if pushforward_performance(backend) isa PushforwardFast
        # Column by column: J * e_j is the j-th column of the Jacobian.
        jvp = similar(y)
        seed = basis(backend, x, first(CartesianIndices(x)))
        extras = prepare_pushforward_same_point(f, backend, x, seed)
        for (kj, j) in enumerate(CartesianIndices(x))
            pushforward!(f, jvp, backend, x, basis(backend, x, j), extras)
            for ki in LinearIndices(jvp)
                if abs(jvp[ki]) > atol
                    push!(rows, ki)
                    push!(cols, kj)
                end
            end
        end
    else
        # Row by row: e_i' * J is the i-th row of the Jacobian.
        vjp = similar(x)
        seed = basis(backend, y, first(CartesianIndices(y)))
        extras = prepare_pullback_same_point(f, backend, x, seed)
        for (ki, i) in enumerate(CartesianIndices(y))
            pullback!(f, vjp, backend, x, basis(backend, y, i), extras)
            for kj in LinearIndices(vjp)
                if abs(vjp[kj]) > atol
                    push!(rows, ki)
                    push!(cols, kj)
                end
            end
        end
    end
    return sparse(rows, cols, ones(Bool, length(rows)), m, n)
end
# In-place variant: same column-by-column / row-by-row strategy for `f!(y, x)`.
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative})
    backend, atol = detector.backend, detector.atol
    n, m = length(x), length(y)
    rows, cols = Int[], Int[]
    if pushforward_performance(backend) isa PushforwardFast
        # Column by column with pushforwards (JVPs).
        jvp = similar(y)
        seed = basis(backend, x, first(CartesianIndices(x)))
        extras = prepare_pushforward_same_point(f!, y, backend, x, seed)
        for (kj, j) in enumerate(CartesianIndices(x))
            pushforward!(f!, y, jvp, backend, x, basis(backend, x, j), extras)
            for ki in LinearIndices(jvp)
                if abs(jvp[ki]) > atol
                    push!(rows, ki)
                    push!(cols, kj)
                end
            end
        end
    else
        # Row by row with pullbacks (VJPs).
        vjp = similar(x)
        seed = basis(backend, y, first(CartesianIndices(y)))
        extras = prepare_pullback_same_point(f!, y, backend, x, seed)
        for (ki, i) in enumerate(CartesianIndices(y))
            pullback!(f!, y, vjp, backend, x, basis(backend, y, i), extras)
            for kj in LinearIndices(vjp)
                if abs(vjp[kj]) > atol
                    push!(rows, ki)
                    push!(cols, kj)
                end
            end
        end
    end
    return sparse(rows, cols, ones(Bool, length(rows)), m, n)
end
# Build the Hessian pattern column by column with Hessian-vector products
# (H * e_j is the j-th column), never materializing the dense Hessian.
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
    backend, atol = detector.backend, detector.atol
    n = length(x)
    rows, cols = Int[], Int[]
    hv = similar(x)
    seed = basis(backend, x, first(CartesianIndices(x)))
    extras = prepare_hvp_same_point(f, backend, x, seed)
    for (kj, j) in enumerate(CartesianIndices(x))
        hvp!(f, hv, backend, x, basis(backend, x, j), extras)
        for ki in LinearIndices(hv)
            if abs(hv[ki]) > atol
                push!(rows, ki)
                push!(cols, kj)
            end
        end
    end
    return sparse(rows, cols, ones(Bool, length(rows)), n, n)
end
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
using ADTypes: jacobian_sparsity, hessian_sparsity
using DifferentiationInterface
using ForwardDiff: ForwardDiff
using Enzyme: Enzyme
using LinearAlgebra
using SparseArrays
using StableRNGs
using Test

rng = StableRNG(63)

# Ground-truth sparsity patterns: a rectangular Jacobian pattern and a
# symmetric Hessian pattern, both random Boolean sparse matrices.
const Jc = sprand(rng, Bool, 10, 20, 0.3)
const Hc = sparse(Symmetric(sprand(rng, Bool, 20, 20, 0.3)))

# Linear map whose Jacobian pattern is exactly `Jc`; the matrix method tests
# non-vector input/output shapes (20 inputs, 10 outputs).
f(x::AbstractVector) = Jc * x
f(x::AbstractMatrix) = reshape(f(vec(x)), (5, 2))

# In-place counterpart of `f`.
function f!(y, x)
    y .= f(x)
    return nothing
end

# Quadratic form whose Hessian pattern is exactly `Hc`.
g(x::AbstractVector) = dot(x, Hc, x)
g(x::AbstractMatrix) = g(vec(x))

@testset verbose = true "$(typeof(backend))" for backend in [
    AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff()
]
    # Invalid `method` keyword must fail at construction time.
    @test_throws ArgumentError DenseSparsityDetector(backend; atol=1e-5, method=:random)
    @testset "$method" for method in (:iterative, :direct)
        detector = DenseSparsityDetector(backend; atol=1e-5, method)
        string(detector)  # exercise the `show` method
        # Vector and matrix input/output shapes.
        for (x, y) in ((rand(20), zeros(10)), (rand(2, 10), zeros(5, 2)))
            @test Jc == jacobian_sparsity(f, x, detector)
            @test Jc == jacobian_sparsity(f!, copy(y), x, detector)
        end
        # Hessian detection needs second-order support, only tested with ForwardDiff.
        if backend isa AutoForwardDiff
            for x in (rand(20), rand(2, 10))
                @test Hc == hessian_sparsity(g, x, detector)
            end
        end
    end
end

0 commit comments

Comments
 (0)