Commit ef2bcb5

AutoSparse should only support Jacobians and Hessians (#277)
* AutoSparse only does Jacobians and Hessians
* Use dense backend explicitly
* Re-export checks
* Typo
* Fix scenarios
* Fix tests
* Type stab
* Types
* Zygote on GPU
* Sparsity on 1.6
* Typo
* Typo
* Doc
1 parent 170a729 · commit ef2bcb5

25 files changed: +271 −240 lines

DifferentiationInterface/Project.toml

Lines changed: 15 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.4.2"
+version = "0.5.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -73,7 +73,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
+# DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
 Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
@@ -95,4 +95,16 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["ADTypes", "Aqua", "DataFrames", "DifferentiationInterfaceTest", "JET", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "Test"]
+test = [
+    "ADTypes",
+    "Aqua",
+    "DataFrames",
+    # "DifferentiationInterfaceTest",
+    "JET",
+    "JuliaFormatter",
+    "Pkg",
+    "SparseArrays",
+    "SparseConnectivityTracer",
+    "SparseMatrixColorings",
+    "Test",
+]

DifferentiationInterface/docs/src/api.md

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,8 @@ hessian!
 check_available
 check_twoarg
 check_hessian
+DifferentiationInterface.outer
+DifferentiationInterface.inner
 ```
 
 ### Backend switch
@@ -116,4 +118,5 @@ The following is not part of the public API.
 ```@autodocs
 Modules = [DifferentiationInterface]
 Public = false
+Filter = t -> !(Symbol(t) in [:outer, :inner])
 ```

DifferentiationInterface/docs/src/operators.md

Lines changed: 10 additions & 9 deletions
@@ -148,9 +148,11 @@ backend = SecondOrder(outer_backend, inner_backend)
 
 The inner backend will be called first, and the outer backend will differentiate the generated code.
 
-!!! warning
-    There are many possible backend combinations, a lot of which will fail.
-    Usually, the most efficient approach for Hessians is forward-over-reverse, i.e. a forward-mode outer backend and a reverse-mode inner backend.
+There are many possible backend combinations, a lot of which will fail.
+Usually, the most efficient approach for Hessians is forward-over-reverse, i.e. a forward-mode outer backend and a reverse-mode inner backend.
+
+!!! danger
+    `SecondOrder` backends do not support first-order operators.
 
 !!! warning
     Preparation does not yet work for the inner differentiation step of a `SecondOrder`, only the outer differentiation is prepared.
@@ -164,23 +166,22 @@ For this to work, three ingredients are needed (read [this survey](https://epubs
 2. A sparsity pattern detector like [`TracerSparsityDetector`](@extref SparseConnectivityTracer.TracerSparsityDetector) from [SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl)
 3. A coloring algorithm like [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl)
 
-These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which Differentiation.jl re-exports.
+These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports.
 Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around.
 
 The preparation step of `jacobian` or `hessian` with an `AutoSparse` backend can be long, because it needs to detect the sparsity pattern and color the resulting sparse matrix.
 But after preparation, the more zeros are present in the matrix, the greater the speedup will be compared to dense differentiation.
 
+!!! danger
+    `AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).
+
 !!! warning
     The result of preparation for an `AutoSparse` backend cannot be reused if the sparsity pattern changes.
 
 !!! info
-    The symbolic backends have built-in sparsity handling, so `AutoSparse(AutoSymbolics())` and `AutoSparse(AutoFastDifferentiation())` do not need additional configuration for pattern detection or coloring.
+    Symbolic backends have built-in sparsity handling, so `AutoSparse(AutoSymbolics())` and `AutoSparse(AutoFastDifferentiation())` do not need additional configuration for pattern detection or coloring.
     However they still benefit from preparation.
 
-!!! warning
-    At the moment, `AutoSparse` backends can be used with operators other than `jacobian` and `hessian`.
-    This possibility will be removed in the next breaking release.
-
 ## Going further
 
 ### Non-standard types
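For context, here is a minimal sketch of the `AutoSparse` workflow this documentation describes, combining the three ingredients. The function `f`, the input size, and the choice of ForwardDiff as the dense backend are illustrative assumptions, not part of the diff.

```julia
# Hedged sketch of the AutoSparse workflow described in the docs above;
# assumes ForwardDiff, SparseConnectivityTracer and SparseMatrixColorings are installed.
using DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
import ForwardDiff

backend = AutoSparse(
    AutoForwardDiff();                            # 1. a dense backend
    sparsity_detector=TracerSparsityDetector(),   # 2. a sparsity pattern detector
    coloring_algorithm=GreedyColoringAlgorithm(), # 3. a coloring algorithm
)

f(x) = diff(x .^ 2)  # banded Jacobian, mostly zeros
x = rand(10)
extras = prepare_jacobian(f, backend, x)  # pattern detection + coloring (can be slow)
J = jacobian(f, backend, x, extras)       # compressed evaluation (fast after preparation)
```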

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 11 additions & 2 deletions
@@ -77,6 +77,8 @@ function __init__()
     @require_extensions
 end
 
+## Exported
+
 export SecondOrder
 
 export value_and_pushforward!, value_and_pushforward
@@ -107,9 +109,8 @@ export check_available, check_twoarg, check_hessian
 
 export DifferentiateWith
 
-export GreedyColoringAlgorithm
+## Re-exported from ADTypes
 
-# Re-export backends from ADTypes
 export AutoChainRules
 export AutoDiffractor
 export AutoEnzyme
@@ -126,4 +127,12 @@ export AutoZygote
 
 export AutoSparse
 
+## Re-exported from SparseMatrixColorings
+
+export GreedyColoringAlgorithm
+
+## Public but not exported
+
+@compat public inner, outer
+
 end # module
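The `@compat public` marker makes `inner` and `outer` part of the public API without exporting them, so callers reach them with a qualified name. A hypothetical usage sketch (backend choices are illustrative):

```julia
# Public-but-not-exported names require qualification.
using DifferentiationInterface
import ForwardDiff, Zygote

backend = SecondOrder(AutoForwardDiff(), AutoZygote())
DifferentiationInterface.outer(backend)  # AutoForwardDiff()
DifferentiationInterface.inner(backend)  # AutoZygote()
```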

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ end
 function hessian(
     f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x)
 ) where {F}
-    new_backend = SecondOrder(backend)
+    new_backend = SecondOrder(backend, backend)
     new_extras = prepare_hessian(f, new_backend, x)
     return hessian(f, new_backend, x, new_extras)
 end
@@ -75,7 +75,7 @@ function hessian!(
     x,
     extras::HessianExtras=prepare_hessian(f, backend, x),
 ) where {F}
-    new_backend = SecondOrder(backend)
+    new_backend = SecondOrder(backend, backend)
     new_extras = prepare_hessian(f, new_backend, x)
     return hessian!(f, hess, new_backend, x, new_extras)
 end
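In other words, calling `hessian` with a single dense backend now explicitly duplicates it on both levels. A minimal sketch of the resulting behavior (the quadratic `f` is an illustrative assumption):

```julia
# Sketch: a dense backend passed to `hessian` is wrapped as
# SecondOrder(backend, backend), the same backend inside and outside.
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)  # Hessian is 2I
H = hessian(f, AutoForwardDiff(), rand(3))
```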

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ end
 function hvp(
     f::F, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v)
 ) where {F}
-    new_backend = SecondOrder(backend)
+    new_backend = SecondOrder(backend, backend)
     new_extras = prepare_hvp(f, new_backend, x, v)
     return hvp(f, new_backend, x, v, new_extras)
 end
@@ -175,7 +175,7 @@ end
 function hvp!(
     f::F, p, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v)
 ) where {F}
-    new_backend = SecondOrder(backend)
+    new_backend = SecondOrder(backend, backend)
     new_extras = prepare_hvp(f, new_backend, x, v)
     return hvp!(f, p, new_backend, x, v, new_extras)
 end
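The same explicit two-argument wrapping now applies to Hessian-vector products. A sketch with an illustrative quadratic function:

```julia
# Sketch: `hvp` with a dense backend also builds SecondOrder(backend, backend).
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
x, v = rand(3), rand(3)
p = hvp(f, AutoForwardDiff(), x, v)  # ≈ 2v, since the Hessian of f is 2I
```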

DifferentiationInterface/src/second_order/second_order.jl

Lines changed: 18 additions & 14 deletions
@@ -3,9 +3,12 @@
 
 Combination of two backends for second-order differentiation.
 
+!!! danger
+    `SecondOrder` backends do not support first-order operators.
+
 # Constructor
 
-    SecondOrder(outer, inner)
+    SecondOrder(outer_backend, inner_backend)
 
 # Fields
@@ -18,26 +21,27 @@ struct SecondOrder{ADO<:AbstractADType,ADI<:AbstractADType} <: AbstractADType
     inner::ADI
 end
 
-SecondOrder(backend::AbstractADType) = SecondOrder(backend, backend)
-
-inner(backend::SecondOrder) = backend.inner
-outer(backend::SecondOrder) = backend.outer
-
 function Base.show(io::IO, backend::SecondOrder)
     return print(io, "SecondOrder($(outer(backend)) / $(inner(backend)))")
 end
 
+"""
+    inner(backend::SecondOrder)
+
+Return the inner backend of a [`SecondOrder`](@ref) object, tasked with differentiation at the first order.
+"""
+inner(backend::SecondOrder) = backend.inner
+
+"""
+    outer(backend::SecondOrder)
+
+Return the outer backend of a [`SecondOrder`](@ref) object, tasked with differentiation at the second order.
+"""
+outer(backend::SecondOrder) = backend.outer
+
 """
     mode(backend::SecondOrder)
 
 Return the _outer_ mode of the second-order backend.
 """
 ADTypes.mode(backend::SecondOrder) = mode(outer(backend))
-
-function twoarg_support(backend::SecondOrder)
-    if Bool(twoarg_support(inner(backend))) && Bool(twoarg_support(outer(backend)))
-        return TwoArgSupported()
-    else
-        return TwoArgNotSupported()
-    end
-end
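With the one-argument constructor removed, a `SecondOrder` must always name both backends. A sketch of the documented forward-over-reverse combination (the backend choices are illustrative):

```julia
# Sketch: explicit two-backend construction; the outer backend drives the mode.
using DifferentiationInterface
import ADTypes, ForwardDiff, Zygote

backend = SecondOrder(AutoForwardDiff(), AutoZygote())  # forward over reverse
ADTypes.mode(backend)  # the *outer* mode, i.e. that of AutoForwardDiff()
```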
Lines changed: 5 additions & 105 deletions
@@ -1,105 +1,5 @@
-## Traits
-
-for trait in (
-    :check_available,
-    :twoarg_support,
-    :pushforward_performance,
-    :pullback_performance,
-    :hvp_mode,
-)
-    @eval $trait(backend::AutoSparse) = $trait(dense_ad(backend))
-end
-
-## Operators
-
-for op in (:pushforward, :pullback, :hvp)
-    op! = Symbol(op, "!")
-    valop = Symbol("value_and_", op)
-    valop! = Symbol("value_and_", op, "!")
-    prep = Symbol("prepare_", op)
-    prepsame = Symbol("prepare_", op, "_same_point")
-    E = if op == :pushforward
-        :PushforwardExtras
-    elseif op == :pullback
-        :PullbackExtras
-    elseif op == :hvp
-        :HVPExtras
-    end
-
-    ## One argument
-    @eval begin
-        $prep(f::F, ba::AutoSparse, x, v) where {F} = $prep(f, dense_ad(ba), x, v)
-        $prepsame(f::F, ba::AutoSparse, x, v) where {F} = $prepsame(f, dense_ad(ba), x, v)
-        $prepsame(f::F, ba::AutoSparse, x, v, ex::$E) where {F} =
-            $prepsame(f, dense_ad(ba), x, v, ex)
-        $op(f::F, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
-            $op(f, dense_ad(ba), x, v, ex)
-        $valop(f::F, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
-            $valop(f, dense_ad(ba), x, v, ex)
-        $op!(f::F, res, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
-            $op!(f, res, dense_ad(ba), x, v, ex)
-        $valop!(f::F, res, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} =
-            $valop!(f, res, dense_ad(ba), x, v, ex)
-    end
-
-    ## Two arguments
-    @eval begin
-        $prep(f!::F, y, ba::AutoSparse, x, v) where {F} = $prep(f!, y, dense_ad(ba), x, v)
-        $prepsame(f!::F, y, ba::AutoSparse, x, v) where {F} =
-            $prepsame(f!, y, dense_ad(ba), x, v)
-        $prepsame(f!::F, y, ba::AutoSparse, x, v, ex::$E) where {F} =
-            $prepsame(f!, y, dense_ad(ba), x, v, ex)
-        $op(f!::F, y, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} =
-            $op(f!, y, dense_ad(ba), x, v, ex)
-        $valop(f!::F, y, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} =
-            $valop(f!, y, dense_ad(ba), x, v, ex)
-        $op!(f!::F, y, res, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} =
-            $op!(f!, y, res, dense_ad(ba), x, v, ex)
-        $valop!(
-            f!::F, y, res, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)
-        ) where {F} = $valop!(f!, y, res, dense_ad(ba), x, v, ex)
-    end
-end
-
-for op in (:derivative, :gradient, :second_derivative)
-    op! = Symbol(op, "!")
-    valop = Symbol("value_and_", op)
-    valop! = Symbol("value_and_", op, "!")
-    prep = Symbol("prepare_", op)
-    E = if op == :derivative
-        :DerivativeExtras
-    elseif op == :gradient
-        :GradientExtras
-    elseif op == :second_derivative
-        :SecondDerivativeExtras
-    end
-
-    ## One argument
-    @eval begin
-        $prep(f::F, ba::AutoSparse, x) where {F} = $prep(f, dense_ad(ba), x)
-        $op(f::F, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
-            $op(f, dense_ad(ba), x, ex)
-        $valop(f::F, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
-            $valop(f, dense_ad(ba), x, ex)
-        $op!(f::F, res, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
-            $op!(f, res, dense_ad(ba), x, ex)
-        $valop!(f::F, res, ba::AutoSparse, x, ex::$E=$prep(f, ba, x)) where {F} =
-            $valop!(f, res, dense_ad(ba), x, ex)
-    end
-
-    ## Two arguments
-    if op in (:derivative,)
-        @eval begin
-            $prep(f!::F, y, ba::AutoSparse, x) where {F} = $prep(f!, y, dense_ad(ba), x)
-            $op(f!::F, y, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)) where {F} =
-                $op(f!, y, dense_ad(ba), x, ex)
-            $valop(f!::F, y, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)) where {F} =
-                $valop(f!, y, dense_ad(ba), x, ex)
-            $op!(f!::F, y, res, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)) where {F} =
-                $op!(f!, y, res, dense_ad(ba), x, ex)
-            $valop!(
-                f!::F, y, res, ba::AutoSparse, x, ex::$E=$prep(f!, y, ba, x)
-            ) where {F} = $valop!(f!, y, res, dense_ad(ba), x, ex)
-        end
-    end
-end
+check_available(backend::AutoSparse) = check_available(dense_ad(backend))
+twoarg_support(backend::AutoSparse) = twoarg_support(dense_ad(backend))
+pushforward_performance(backend::AutoSparse) = pushforward_performance(dense_ad(backend))
+pullback_performance(backend::AutoSparse) = pullback_performance(dense_ad(backend))
+hvp_mode(backend::AutoSparse{<:SecondOrder}) = hvp_mode(dense_ad(backend))
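Only backend traits are forwarded now; the operator fallbacks are gone, so first-order calls on an `AutoSparse` backend no longer silently hit the dense backend. A sketch of what the remaining delegation looks like from the outside (the backend choice is illustrative):

```julia
# Sketch: trait queries on AutoSparse are answered by the wrapped dense backend.
using DifferentiationInterface
import ADTypes, ForwardDiff

sparse_backend = AutoSparse(AutoForwardDiff())
ADTypes.dense_ad(sparse_backend)  # AutoForwardDiff()
check_available(sparse_backend)   # true iff the dense backend is available
```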

DifferentiationInterface/src/sparse/hessian.jl

Lines changed: 8 additions & 5 deletions
@@ -17,6 +17,7 @@ end
 ## Hessian, one argument
 
 function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
+    dense_backend = dense_ad(backend)
     initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
     sparsity = col_major(initial_sparsity)
     colors = symmetric_coloring(sparsity, coloring_algorithm(backend))
@@ -26,7 +27,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
         seed[group] .= one(eltype(x))
         seed
     end
-    hvp_extras = prepare_hvp(f, backend, x, first(seeds))
+    hvp_extras = prepare_hvp(f, dense_backend, x, first(seeds))
     products = map(seeds) do _
         similar(x)
     end
@@ -36,9 +37,10 @@ end
 
 function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras) where {F}
     @compat (; sparsity, compressed, colors, seeds, products, hvp_extras) = extras
-    hvp_extras_same = prepare_hvp_same_point(f, backend, x, seeds[1], hvp_extras)
+    dense_backend = dense_ad(backend)
+    hvp_extras_same = prepare_hvp_same_point(f, dense_backend, x, seeds[1], hvp_extras)
     for k in eachindex(seeds, products)
-        hvp!(f, products[k], backend, x, seeds[k], hvp_extras_same)
+        hvp!(f, products[k], dense_backend, x, seeds[k], hvp_extras_same)
         copyto!(view(compressed, :, k), vec(products[k]))
     end
     decompress_symmetric!(hess, sparsity, compressed, colors)
@@ -47,9 +49,10 @@ end
 
 function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) where {F}
     @compat (; sparsity, compressed, colors, seeds, products, hvp_extras) = extras
-    hvp_extras_same = prepare_hvp_same_point(f, backend, x, seeds[1], hvp_extras)
+    dense_backend = dense_ad(backend)
+    hvp_extras_same = prepare_hvp_same_point(f, dense_backend, x, seeds[1], hvp_extras)
     compressed = stack(eachindex(seeds, products); dims=2) do k
-        vec(hvp(f, backend, x, seeds[k], hvp_extras_same))
+        vec(hvp(f, dense_backend, x, seeds[k], hvp_extras_same))
     end
     return decompress_symmetric(sparsity, compressed, colors)
 end
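A hedged end-to-end sketch of the sparse Hessian path above, with the `SecondOrder` backend inside the `AutoSparse` wrapper as the docs require. The function and backend choices are illustrative, and the detector is assumed to implement `hessian_sparsity`.

```julia
# Sketch of a sparse Hessian computation; note SecondOrder goes *inside* AutoSparse.
using DifferentiationInterface
using SparseConnectivityTracer: TracerSparsityDetector
import ForwardDiff, Zygote

backend = AutoSparse(
    SecondOrder(AutoForwardDiff(), AutoZygote());
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)

f(x) = sum(abs2, diff(x))  # tridiagonal Hessian
x = rand(10)
extras = prepare_hessian(f, backend, x)  # detection, coloring, seed construction
H = hessian(f, backend, x, extras)       # one HVP per color, then decompression
```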
