Skip to content

Commit 294ec09

Browse files
committed
feat: using DI for structured Jacobians
1 parent fb8b1ee commit 294ec09

File tree

3 files changed

+22
-103
lines changed

3 files changed

+22
-103
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3030
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
3131
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3232
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
33-
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
3433
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
3534
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3635
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
@@ -146,6 +145,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
146145
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
147146
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
148147
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
148+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
149149
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
150150
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
151151
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -155,4 +155,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
155155
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
156156

157157
[targets]
158-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
158+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseDiffTools", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]

src/NonlinearSolve.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ using SciMLJacobianOperators: AbstractJacobianOperator, JacobianOperator, VecJac
5454
## Sparse AD Support
5555
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
5656
using SparseConnectivityTracer: TracerSparsityDetector # This can be dropped in the next release
57-
using SparseDiffTools: SparseDiffTools, JacPrototypeSparsityDetection,
58-
PrecomputedJacobianColorvec, init_jacobian, sparse_jacobian,
59-
sparse_jacobian!, sparse_jacobian_cache
6057
using SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm,
6158
LargestFirst
6259

src/internal/jacobian.jl

Lines changed: 20 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
3737
stats::NLStats
3838
autodiff
3939
di_extras
40-
sdifft_extras
4140
end
4241

4342
function reinit_cache!(cache::JacobianCache{iip}, args...; p = cache.p,
@@ -63,31 +62,13 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
6362

6463
if !has_analytic_jac && needs_jac
6564
autodiff = construct_concrete_adtype(f, autodiff)
66-
using_sparsedifftools = autodiff isa StructuredMatrixAutodiff
67-
# SparseMatrixColorings can't handle structured matrices
68-
if using_sparsedifftools
69-
di_extras = nothing
70-
uf = JacobianWrapper{iip}(f, p)
71-
sdifft_extras = if iip
72-
sparse_jacobian_cache(
73-
autodiff.autodiff, autodiff.sparsity_detection, uf, fu, u)
74-
else
75-
sparse_jacobian_cache(autodiff.autodiff, autodiff.sparsity_detection,
76-
uf, __maybe_mutable(u, autodiff); fx = fu)
77-
end
78-
autodiff = autodiff.autodiff # For saving we unwrap
65+
di_extras = if iip
66+
DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p))
7967
else
80-
sdifft_extras = nothing
81-
di_extras = if iip
82-
DI.prepare_jacobian(f, fu, autodiff, u, Constant(prob.p))
83-
else
84-
DI.prepare_jacobian(f, autodiff, u, Constant(prob.p))
85-
end
68+
DI.prepare_jacobian(f, autodiff, u, Constant(prob.p))
8669
end
8770
else
88-
using_sparsedifftools = false
8971
di_extras = nothing
90-
sdifft_extras = nothing
9172
end
9273

9374
J = if !needs_jac
@@ -98,22 +79,18 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
9879
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
9980
else
10081
if f.jac_prototype === nothing
101-
if !using_sparsedifftools
102-
# While this is technically wasteful, it gives out the type of the Jacobian
103-
# which is needed to create the linear solver cache
104-
stats.njacs += 1
105-
if has_analytic_jac
106-
__similar(
107-
fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
82+
# While this is technically wasteful, it gives out the type of the Jacobian
83+
# which is needed to create the linear solver cache
84+
stats.njacs += 1
85+
if has_analytic_jac
86+
__similar(
87+
fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
88+
else
89+
if iip
90+
DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p))
10891
else
109-
if iip
110-
DI.jacobian(f, fu, di_extras, autodiff, u, Constant(p))
111-
else
112-
DI.jacobian(f, autodiff, u, Constant(p))
113-
end
92+
DI.jacobian(f, autodiff, u, Constant(p))
11493
end
115-
else
116-
zero(init_jacobian(sdifft_extras; preserve_immutable = Val(true)))
11794
end
11895
else
11996
jac_proto = if eltype(f.jac_prototype) <: Bool
@@ -126,20 +103,19 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
126103
end
127104
end
128105

129-
return JacobianCache{iip}(
130-
J, f, fu, u, p, stats, autodiff, di_extras, sdifft_extras)
106+
return JacobianCache{iip}(J, f, fu, u, p, stats, autodiff, di_extras)
131107
end
132108

133109
function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; stats,
134110
autodiff = nothing, kwargs...) where {F}
135111
fu = f(u, p)
136112
if SciMLBase.has_jac(f) || SciMLBase.has_vjp(f) || SciMLBase.has_jvp(f)
137-
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, nothing, nothing)
113+
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, nothing)
138114
end
139115
autodiff = get_dense_ad(get_concrete_forward_ad(
140116
autodiff, prob; check_forward_mode = false))
141117
di_extras = DI.prepare_derivative(f, get_dense_ad(autodiff), u, Constant(prob.p))
142-
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, di_extras, nothing)
118+
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, di_extras)
143119
end
144120

145121
(cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
@@ -172,27 +148,16 @@ function (cache::JacobianCache{iip})(
172148
if iip
173149
if SciMLBase.has_jac(cache.f)
174150
cache.f.jac(J, u, p)
175-
elseif cache.di_extras !== nothing
151+
else
176152
DI.jacobian!(
177153
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p))
178-
else
179-
uf = JacobianWrapper{iip}(cache.f, p)
180-
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, cache.fu, u)
181154
end
182155
return J
183156
else
184157
if SciMLBase.has_jac(cache.f)
185158
return cache.f.jac(u, p)
186-
elseif cache.di_extras !== nothing
187-
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
188159
else
189-
uf = JacobianWrapper{iip}(cache.f, p)
190-
if __can_setindex(typeof(J))
191-
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, u)
192-
return J
193-
else
194-
return sparse_jacobian(cache.autodiff, cache.sdifft_extras, uf, u)
195-
end
160+
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
196161
end
197162
end
198163
end
@@ -207,10 +172,6 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType)
207172
end
208173
return ad # No sparse AD
209174
else
210-
if ArrayInterface.isstructured(f.jac_prototype)
211-
return select_fastest_structured_matrix_autodiff(f.jac_prototype, f, ad)
212-
end
213-
214175
return AutoSparse(
215176
ad;
216177
sparsity_detector = KnownJacobianSparsityDetector(f.jac_prototype),
@@ -227,10 +188,6 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType)
227188
Base.depwarn("`sparsity::typeof($(typeof(f.sparsity)))` is deprecated. \
228189
Pass it as `jac_prototype` instead.",
229190
:NonlinearSolve)
230-
if ArrayInterface.isstructured(f.sparsity)
231-
return select_fastest_structured_matrix_autodiff(f.sparsity, f, ad)
232-
end
233-
234191
return AutoSparse(
235192
ad;
236193
sparsity_detector = KnownJacobianSparsityDetector(f.sparsity),
@@ -252,11 +209,8 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType)
252209
coloring_algorithm = GreedyColoringAlgorithm(LargestFirst())
253210
)
254211
else
255-
if ArrayInterface.isstructured(f.jac_prototype)
256-
return select_fastest_structured_matrix_autodiff(f.jac_prototype, f, ad)
257-
end
258-
259-
if f.jac_prototype isa AbstractSparseMatrix
212+
if f.jac_prototype isa AbstractSparseMatrix ||
213+
ArrayInterface.isstructured(f.jac_prototype)
260214
if !(sparsity_detector isa NoSparsityDetector)
261215
@warn "`jac_prototype` is a sparse matrix but sparsity = $(f.sparsity) \
262216
has also been specified. Ignoring sparsity field and using \
@@ -275,38 +229,6 @@ function construct_concrete_adtype(f::NonlinearFunction, ad::AbstractADType)
275229
end
276230
end
277231

278-
@concrete struct StructuredMatrixAutodiff <: AbstractADType
279-
autodiff <: AbstractADType
280-
sparsity_detection
281-
end
282-
283-
function select_fastest_structured_matrix_autodiff(
284-
prototype::AbstractMatrix, f::NonlinearFunction, ad::AbstractADType)
285-
sparsity_detection = if SciMLBase.has_colorvec(f)
286-
PrecomputedJacobianColorvec(;
287-
jac_prototype = prototype,
288-
f.colorvec,
289-
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode
290-
)
291-
else
292-
if ArrayInterface.fast_matrix_colors(prototype)
293-
colorvec = if ADTypes.mode(ad) isa ADTypes.ForwardMode
294-
ArrayInterface.matrix_colors(prototype)
295-
else
296-
ArrayInterface.matrix_colors(prototype')
297-
end
298-
PrecomputedJacobianColorvec(;
299-
jac_prototype = prototype,
300-
colorvec,
301-
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode
302-
)
303-
else
304-
JacPrototypeSparsityDetection(; jac_prototype = prototype)
305-
end
306-
end
307-
return StructuredMatrixAutodiff(AutoSparse(ad), sparsity_detection)
308-
end
309-
310232
function select_fastest_coloring_algorithm(
311233
prototype, f::NonlinearFunction, ad::AbstractADType)
312234
if SciMLBase.has_colorvec(f)

0 commit comments

Comments
 (0)