Skip to content

Commit 864f215

Browse files
authored
fix: support Cache in sparsity detection with SparseConnectivityTracer (#739)
* fix: support `Cache` in sparsity detection with SparseConnectivityTracer * Dead import
1 parent d29659e commit 864f215

File tree

10 files changed

+109
-29
lines changed

10 files changed

+109
-29
lines changed

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.44"
4+
version = "0.6.45"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -22,6 +22,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2222
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2323
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2424
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
25+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
2526
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
2627
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2728
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -41,6 +42,7 @@ DifferentiationInterfaceMooncakeExt = "Mooncake"
4142
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
4243
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4344
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
45+
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
4446
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
4547
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
4648
DifferentiationInterfaceSymbolicsExt = "Symbolics"
@@ -66,7 +68,7 @@ Mooncake = "0.4.88"
6668
PolyesterForwardDiff = "0.1.2"
6769
ReverseDiff = "1.15.1"
6870
SparseArrays = "<0.0.1,1"
69-
SparseConnectivityTracer = "0.5.0,0.6"
71+
SparseConnectivityTracer = "0.6.14"
7072
SparseMatrixColorings = "0.4.9"
7173
StaticArrays = "1.9.7"
7274
Symbolics = "5.27.1, 6"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
module DifferentiationInterfaceSparseConnectivityTracerExt
2+
3+
using ADTypes: jacobian_sparsity, hessian_sparsity
4+
import DifferentiationInterface as DI
5+
using SparseConnectivityTracer:
6+
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
7+
8+
@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c)
9+
@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray})
10+
return jacobian_buffer(DI.unwrap(c), detector)
11+
end
12+
13+
function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
14+
new_contexts = map(contexts) do c
15+
_jacobian_translate(detector, c)
16+
end
17+
return new_contexts
18+
end
19+
20+
@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c)
21+
@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray})
22+
return hessian_buffer(DI.unwrap(c), detector)
23+
end
24+
25+
function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
26+
new_contexts = map(contexts) do c
27+
_hessian_translate(detector, c)
28+
end
29+
return new_contexts
30+
end
31+
32+
function DI.jacobian_sparsity_with_contexts(
33+
f::F,
34+
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
35+
x,
36+
contexts::Vararg{DI.Context,C},
37+
) where {F,C}
38+
contexts_tracer = jacobian_translate(detector, contexts...)
39+
fc = DI.FixTail(f, contexts_tracer)
40+
return jacobian_sparsity(fc, x, detector)
41+
end
42+
43+
function DI.jacobian_sparsity_with_contexts(
44+
f!::F,
45+
y,
46+
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
47+
x,
48+
contexts::Vararg{DI.Context,C},
49+
) where {F,C}
50+
contexts_tracer = jacobian_translate(detector, contexts...)
51+
fc! = DI.FixTail(f!, contexts_tracer)
52+
return jacobian_sparsity(fc!, y, x, detector)
53+
end
54+
55+
function DI.hessian_sparsity_with_contexts(
56+
f::F,
57+
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
58+
x,
59+
contexts::Vararg{DI.Context,C},
60+
) where {F,C}
61+
contexts_tracer = hessian_translate(detector, contexts...)
62+
fc = DI.FixTail(f, contexts_tracer)
63+
return hessian_sparsity(fc, x, detector)
64+
end
65+
66+
end

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
module DifferentiationInterfaceSparseMatrixColoringsExt
22

3-
using ADTypes:
4-
ADTypes,
5-
AutoSparse,
6-
coloring_algorithm,
7-
dense_ad,
8-
sparsity_detector,
9-
jacobian_sparsity,
10-
hessian_sparsity
3+
using ADTypes: ADTypes, AutoSparse, coloring_algorithm, dense_ad, sparsity_detector
114
import DifferentiationInterface as DI
125
using SparseMatrixColorings:
136
AbstractColoringResult,
@@ -22,14 +15,6 @@ using SparseMatrixColorings:
2215
decompress!
2316
import SparseMatrixColorings as SMC
2417

25-
function fycont(f, contexts::Vararg{DI.Context,C}) where {C}
26-
return (DI.with_contexts(f, contexts...),)
27-
end
28-
29-
function fycont(f!, y, contexts::Vararg{DI.Context,C}) where {C}
30-
return (DI.with_contexts(f!, contexts...), y)
31-
end
32-
3318
abstract type SparseJacobianPrep <: DI.JacobianPrep end
3419

3520
SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result)

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ function DI.prepare_hessian(
2727
f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
2828
) where {F,C}
2929
dense_backend = dense_ad(backend)
30-
sparsity = hessian_sparsity(
31-
DI.with_contexts(f, contexts...), x, sparsity_detector(backend)
30+
sparsity = DI.hessian_sparsity_with_contexts(
31+
f, sparsity_detector(backend), x, contexts...
3232
)
3333
problem = ColoringProblem{:symmetric,:column}()
3434
coloring_result = coloring(

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ function _prepare_sparse_jacobian_aux(
5858
contexts::Vararg{DI.Context,C},
5959
) where {FY,C}
6060
dense_backend = dense_ad(backend)
61-
sparsity = jacobian_sparsity(
62-
fycont(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
61+
sparsity = DI.jacobian_sparsity_with_contexts(
62+
f_or_f!y..., sparsity_detector(backend), x, contexts...
6363
)
6464
if perf isa DI.PushforwardFast
6565
problem = ColoringProblem{:nonsymmetric,:column}()

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ function _prepare_mixed_sparse_jacobian_aux(
4242
y, f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}
4343
) where {FY,C}
4444
dense_backend = dense_ad(backend)
45-
sparsity = jacobian_sparsity(
46-
fycont(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
45+
sparsity = DI.jacobian_sparsity_with_contexts(
46+
f_or_f!y..., sparsity_detector(backend), x, contexts...
4747
)
4848
problem = ColoringProblem{:nonsymmetric,:bidirectional}()
4949
coloring_result = coloring(

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ module DifferentiationInterface
88
using ADTypes:
99
ADTypes,
1010
AbstractADType,
11+
AbstractSparsityDetector,
1112
AutoSparse,
1213
ForwardMode,
1314
ForwardOrReverseMode,
1415
ReverseMode,
1516
SymbolicMode,
1617
dense_ad,
17-
mode
18+
mode,
19+
jacobian_sparsity,
20+
hessian_sparsity
1821
using ADTypes:
1922
AutoChainRules,
2023
AutoDiffractor,
@@ -45,6 +48,7 @@ include("utils/check.jl")
4548
include("utils/printing.jl")
4649
include("utils/context.jl")
4750
include("utils/linalg.jl")
51+
include("utils/sparse.jl")
4852

4953
include("first_order/pushforward.jl")
5054
include("first_order/pullback.jl")
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
function jacobian_sparsity_with_contexts(
2+
f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C}
3+
) where {F,C}
4+
return jacobian_sparsity(with_contexts(f, contexts...), x, detector)
5+
end
6+
7+
function jacobian_sparsity_with_contexts(
8+
f!::F, y, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C}
9+
) where {F,C}
10+
return jacobian_sparsity(with_contexts(f!, contexts...), y, x, detector)
11+
end
12+
13+
function hessian_sparsity_with_contexts(
14+
f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C}
15+
) where {F,C}
16+
return hessian_sparsity(with_contexts(f, contexts...), x, detector)
17+
end

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ end
5757
MyAutoSparse.(
5858
vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2]))
5959
),
60-
sparse_scenarios(; include_constantified=true);
60+
sparse_scenarios(; include_constantified=true, include_cachified=true);
6161
sparsity=true,
6262
logging=LOGGING,
6363
)

DifferentiationInterfaceTest/src/scenarios/sparse.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ end
324324
325325
Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians.
326326
"""
327-
function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false)
327+
function sparse_scenarios(;
328+
band_sizes=[5, 10, 20], include_constantified=false, include_cachified=false
329+
)
328330
x_6 = float.(1:6)
329331
x_2_3 = float.(reshape(1:6, 2, 3))
330332
x_50 = float.(range(1, 2, 50))
@@ -341,6 +343,10 @@ function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false)
341343
append!(scens, squarelinearmap_scenarios(x_50, band_sizes))
342344
append!(scens, squarequadraticform_scenarios(x_50, band_sizes))
343345
end
344-
include_constantified && append!(scens, constantify(scens))
345-
return scens
346+
347+
final_scens = Scenario[]
348+
append!(final_scens, scens)
349+
include_constantified && append!(final_scens, constantify(scens))
350+
include_cachified && append!(final_scens, cachify(scens))
351+
return final_scens
346352
end

0 commit comments

Comments
 (0)