Skip to content

Commit aa6334b

Browse files
committed
fix: don't ignore sparsity for dense_ad
1 parent 301153d commit aa6334b

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/NonlinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ end
77
using Reexport: @reexport
88
using PrecompileTools: @compile_workload, @setup_workload
99

10-
using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff,
11-
AutoZygote, AutoEnzyme, AutoSparse
10+
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
11+
AutoPolyesterForwardDiff, AutoZygote, AutoEnzyme, AutoSparse
1212
# FIXME: deprecated, remove in future
1313
using ADTypes: AutoSparseFiniteDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff,
1414
AutoSparseZygote

src/internal/jacobian.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
6060

6161
if !has_analytic_jac && needs_jac
6262
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
63-
sd = __sparsity_detection_alg(f, autodiff)
63+
sd = sparsity_detection_alg(f, autodiff)
6464
sparse_jac = !(sd isa NoSparsityDetection)
6565
# Eventually we want to do everything via DI. But for now, we just do the dense via DI
6666
if sparse_jac
@@ -174,9 +174,13 @@ function (cache::JacobianCache{iip})(
174174
end
175175
end
176176

177-
# Sparsity Detection Choices
178-
@inline __sparsity_detection_alg(_, _) = NoSparsityDetection()
179-
@inline function __sparsity_detection_alg(f::NonlinearFunction, ad::AutoSparse)
177+
function sparsity_detection_alg(f::NonlinearFunction, ad::AbstractADType)
178+
# TODO: Also handle case where colorvec is provided
179+
f.sparsity === nothing && return NoSparsityDetection()
180+
return sparsity_detection_alg(f, AutoSparse(ad; sparsity_detector = f.sparsity))
181+
end
182+
183+
function sparsity_detection_alg(f::NonlinearFunction, ad::AutoSparse)
180184
if f.sparsity === nothing
181185
if f.jac_prototype === nothing
182186
is_extension_loaded(Val(:Symbolics)) && return SymbolicsSparsityDetection()
@@ -200,8 +204,7 @@ end
200204
end
201205

202206
if SciMLBase.has_colorvec(f)
203-
return PrecomputedJacobianColorvec(; jac_prototype,
204-
f.colorvec,
207+
return PrecomputedJacobianColorvec(; jac_prototype, f.colorvec,
205208
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode)
206209
else
207210
return JacPrototypeSparsityDetection(; jac_prototype)

0 commit comments

Comments
 (0)