Skip to content

Commit 254c1d5

Browse files
committed
feat: use DI for dense jacobians
1 parent 21bdd82 commit 254c1d5

File tree

1 file changed

+55
-36
lines changed

1 file changed

+55
-36
lines changed

src/internal/jacobian.jl

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,25 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
3131
@concrete mutable struct JacobianCache{iip} <: AbstractNonlinearSolveJacobianCache{iip}
3232
J
3333
f
34-
uf
3534
fu
3635
u
3736
p
38-
jac_cache
3937
alg
4038
stats::NLStats
4139
autodiff
4240
di_extras
41+
sdifft_extras
4342
end
4443

4544
function reinit_cache!(cache::JacobianCache{iip}, args...; p = cache.p,
4645
u0 = cache.u, kwargs...) where {iip}
4746
cache.u = u0
4847
cache.p = p
49-
cache.uf = JacobianWrapper{iip}(cache.f, p)
5048
end
5149

5250
function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
5351
vjp_autodiff = nothing, jvp_autodiff = nothing, linsolve = missing) where {F}
5452
iip = isinplace(prob)
55-
uf = JacobianWrapper{iip}(f, p)
56-
57-
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
5853

5954
has_analytic_jac = SciMLBase.has_jac(f)
6055
linsolve_needs_jac = concrete_jac(alg) === nothing && (linsolve === missing ||
@@ -65,12 +60,31 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
6560
@bb fu = similar(fu_)
6661

6762
if !has_analytic_jac && needs_jac
63+
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
6864
sd = __sparsity_detection_alg(f, autodiff)
69-
jac_cache = iip ? sparse_jacobian_cache(autodiff, sd, uf, fu, u) :
70-
sparse_jacobian_cache(
71-
autodiff, sd, uf, __maybe_mutable(u, autodiff); fx = fu)
65+
sparse_jac = !(sd isa NoSparsityDetection)
66+
# Eventually we want to do everything via DI. But for now, we just do the dense via DI
67+
if sparse_jac
68+
di_extras = nothing
69+
uf = JacobianWrapper{iip}(f, p)
70+
sdifft_extras = if iip
71+
sparse_jacobian_cache(autodiff, sd, uf, fu, u)
72+
else
73+
sparse_jacobian_cache(
74+
autodiff, sd, uf, __maybe_mutable(u, autodiff); fx = fu)
75+
end
76+
else
77+
sdifft_extras = nothing
78+
di_extras = if iip
79+
DI.prepare_jacobian(f, fu, autodiff, u, Constant(p))
80+
else
81+
DI.prepare_jacobian(f, autodiff, u, Constant(p))
82+
end
83+
end
7284
else
73-
jac_cache = nothing
85+
sparse_jac = false
86+
di_extras = nothing
87+
sdifft_extras = nothing
7488
end
7589

7690
J = if !needs_jac
@@ -80,36 +94,34 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
8094
vjp_autodiff, prob, Val(false); check_reverse_mode = false)
8195
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
8296
else
83-
if has_analytic_jac
84-
f.jac_prototype === nothing ?
85-
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
86-
copy(f.jac_prototype)
87-
elseif f.jac_prototype === nothing
88-
zero(init_jacobian(jac_cache; preserve_immutable = Val(true)))
97+
if f.jac_prototype === nothing
98+
if !sparse_jac
99+
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
100+
else
101+
zero(init_jacobian(sdifft_extras; preserve_immutable = Val(true)))
102+
end
89103
else
90-
f.jac_prototype
104+
similar(f.jac_prototype)
91105
end
92106
end
93107

94108
return JacobianCache{iip}(
95-
J, f, uf, fu, u, p, jac_cache, alg, stats, autodiff, nothing)
109+
J, f, fu, u, p, alg, stats, autodiff, di_extras, sdifft_extras)
96110
end
97111

98112
function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; stats,
99113
autodiff = nothing, kwargs...) where {F}
100114
fu = f(u, p)
101115
if SciMLBase.has_jac(f) || SciMLBase.has_vjp(f) || SciMLBase.has_jvp(f)
102-
return JacobianCache{false}(
103-
u, f, nothing, fu, u, p, nothing, alg, stats, autodiff, nothing)
116+
return JacobianCache{false}(u, f, fu, u, p, alg, stats, autodiff, nothing)
104117
end
105118
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
106119
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
107-
return JacobianCache{false}(
108-
u, f, nothing, fu, u, p, nothing, alg, stats, autodiff, di_extras)
120+
return JacobianCache{false}(u, f, fu, u, p, alg, stats, autodiff, di_extras, nothing)
109121
end
110122

111-
@inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
112-
@inline function (cache::JacobianCache)(::Nothing)
123+
(cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
124+
function (cache::JacobianCache)(::Nothing)
113125
cache.J isa JacobianOperator &&
114126
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
115127
return cache.J
@@ -136,23 +148,31 @@ function (cache::JacobianCache{iip})(
136148
J::Union{AbstractMatrix, Nothing}, u, p = cache.p) where {iip}
137149
cache.stats.njacs += 1
138150
if iip
139-
if has_jac(cache.f)
151+
if SciMLBase.has_jac(cache.f)
140152
cache.f.jac(J, u, p)
153+
elseif cache.di_extras !== nothing
154+
DI.jacobian!(
155+
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p))
141156
else
142-
sparse_jacobian!(J, cache.autodiff, cache.jac_cache, cache.uf, cache.fu, u)
157+
uf = JacobianWrapper{iip}(cache.f, p)
158+
sparse_jacobian!(J, cache.autodiff, cache.jac_cache, uf, cache.fu, u)
143159
end
144-
J_ = J
160+
return J
145161
else
146-
J_ = if has_jac(cache.f)
147-
cache.f.jac(u, p)
148-
elseif __can_setindex(typeof(J))
149-
sparse_jacobian!(J, cache.autodiff, cache.jac_cache, cache.uf, u)
150-
J
162+
if SciMLBase.has_jac(cache.f)
163+
return cache.f.jac(u, p)
164+
elseif cache.di_extras !== nothing
165+
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
151166
else
152-
sparse_jacobian(cache.autodiff, cache.jac_cache, cache.uf, u)
167+
uf = JacobianWrapper{iip}(cache.f, p)
168+
if __can_setindex(typeof(J))
169+
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, u)
170+
return J
171+
else
172+
return sparse_jacobian(cache.autodiff, cache.sdifft_extras, uf, u)
173+
end
153174
end
154175
end
155-
return J_
156176
end
157177

158178
# Sparsity Detection Choices
@@ -183,8 +203,7 @@ end
183203
if SciMLBase.has_colorvec(f)
184204
return PrecomputedJacobianColorvec(; jac_prototype,
185205
f.colorvec,
186-
partition_by_rows = (ad isa AutoSparse &&
187-
ADTypes.mode(ad) isa ADTypes.ReverseMode))
206+
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode)
188207
else
189208
return JacPrototypeSparsityDetection(; jac_prototype)
190209
end

0 commit comments

Comments
 (0)