@@ -31,30 +31,25 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
31
31
@concrete mutable struct JacobianCache{iip} <: AbstractNonlinearSolveJacobianCache{iip}
32
32
J
33
33
f
34
- uf
35
34
fu
36
35
u
37
36
p
38
- jac_cache
39
37
alg
40
38
stats:: NLStats
41
39
autodiff
42
40
di_extras
41
+ sdifft_extras
43
42
end
44
43
45
44
function reinit_cache! (cache:: JacobianCache{iip} , args... ; p = cache. p,
46
45
u0 = cache. u, kwargs... ) where {iip}
47
46
cache. u = u0
48
47
cache. p = p
49
- cache. uf = JacobianWrapper {iip} (cache. f, p)
50
48
end
51
49
52
50
function JacobianCache (prob, alg, f:: F , fu_, u, p; stats, autodiff = nothing ,
53
51
vjp_autodiff = nothing , jvp_autodiff = nothing , linsolve = missing ) where {F}
54
52
iip = isinplace (prob)
55
- uf = JacobianWrapper {iip} (f, p)
56
-
57
- autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
58
53
59
54
has_analytic_jac = SciMLBase. has_jac (f)
60
55
linsolve_needs_jac = concrete_jac (alg) === nothing && (linsolve === missing ||
@@ -65,12 +60,31 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
65
60
@bb fu = similar (fu_)
66
61
67
62
if ! has_analytic_jac && needs_jac
63
+ autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
68
64
sd = __sparsity_detection_alg (f, autodiff)
69
- jac_cache = iip ? sparse_jacobian_cache (autodiff, sd, uf, fu, u) :
70
- sparse_jacobian_cache (
71
- autodiff, sd, uf, __maybe_mutable (u, autodiff); fx = fu)
65
+ sparse_jac = ! (sd isa NoSparsityDetection)
66
+ # Eventually we want to do everything via DI. But for now, we just do the dense via DI
67
+ if sparse_jac
68
+ di_extras = nothing
69
+ uf = JacobianWrapper {iip} (f, p)
70
+ sdifft_extras = if iip
71
+ sparse_jacobian_cache (autodiff, sd, uf, fu, u)
72
+ else
73
+ sparse_jacobian_cache (
74
+ autodiff, sd, uf, __maybe_mutable (u, autodiff); fx = fu)
75
+ end
76
+ else
77
+ sdifft_extras = nothing
78
+ di_extras = if iip
79
+ DI. prepare_jacobian (f, fu, autodiff, u, Constant (p))
80
+ else
81
+ DI. prepare_jacobian (f, autodiff, u, Constant (p))
82
+ end
83
+ end
72
84
else
73
- jac_cache = nothing
85
+ sparse_jac = false
86
+ di_extras = nothing
87
+ sdifft_extras = nothing
74
88
end
75
89
76
90
J = if ! needs_jac
@@ -80,36 +94,34 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
80
94
vjp_autodiff, prob, Val (false ); check_reverse_mode = false )
81
95
JacobianOperator (prob, fu, u; jvp_autodiff, vjp_autodiff)
82
96
else
83
- if has_analytic_jac
84
- f . jac_prototype === nothing ?
85
- __similar (fu, promote_type (eltype (fu), eltype (u)), length (fu), length (u)) :
86
- copy (f . jac_prototype)
87
- elseif f . jac_prototype === nothing
88
- zero ( init_jacobian (jac_cache; preserve_immutable = Val ( true )))
97
+ if f . jac_prototype === nothing
98
+ if ! sparse_jac
99
+ __similar (fu, promote_type (eltype (fu), eltype (u)), length (fu), length (u))
100
+ else
101
+ zero ( init_jacobian (sdifft_extras; preserve_immutable = Val ( true )))
102
+ end
89
103
else
90
- f. jac_prototype
104
+ similar ( f. jac_prototype)
91
105
end
92
106
end
93
107
94
108
return JacobianCache {iip} (
95
- J, f, uf, fu, u, p, jac_cache, alg, stats, autodiff, nothing )
109
+ J, f, fu, u, p, alg, stats, autodiff, di_extras, sdifft_extras )
96
110
end
97
111
98
112
function JacobianCache (prob, alg, f:: F , :: Number , u:: Number , p; stats,
99
113
autodiff = nothing , kwargs... ) where {F}
100
114
fu = f (u, p)
101
115
if SciMLBase. has_jac (f) || SciMLBase. has_vjp (f) || SciMLBase. has_jvp (f)
102
- return JacobianCache {false} (
103
- u, f, nothing , fu, u, p, nothing , alg, stats, autodiff, nothing )
116
+ return JacobianCache {false} (u, f, fu, u, p, alg, stats, autodiff, nothing )
104
117
end
105
118
autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
106
119
di_extras = DI. prepare_derivative (f, autodiff, u, Constant (prob. p))
107
- return JacobianCache {false} (
108
- u, f, nothing , fu, u, p, nothing , alg, stats, autodiff, di_extras)
120
+ return JacobianCache {false} (u, f, fu, u, p, alg, stats, autodiff, di_extras, nothing )
109
121
end
110
122
111
- @inline (cache:: JacobianCache )(u = cache. u) = cache (cache. J, u, cache. p)
112
- @inline function (cache:: JacobianCache )(:: Nothing )
123
+ (cache:: JacobianCache )(u = cache. u) = cache (cache. J, u, cache. p)
124
+ function (cache:: JacobianCache )(:: Nothing )
113
125
cache. J isa JacobianOperator &&
114
126
return StatefulJacobianOperator (cache. J, cache. u, cache. p)
115
127
return cache. J
@@ -136,23 +148,31 @@ function (cache::JacobianCache{iip})(
136
148
J:: Union{AbstractMatrix, Nothing} , u, p = cache. p) where {iip}
137
149
cache. stats. njacs += 1
138
150
if iip
139
- if has_jac (cache. f)
151
+ if SciMLBase . has_jac (cache. f)
140
152
cache. f. jac (J, u, p)
153
+ elseif cache. di_extras != = nothing
154
+ DI. jacobian! (
155
+ cache. f, cache. fu, J, cache. di_extras, cache. autodiff, u, Constant (p))
141
156
else
142
- sparse_jacobian! (J, cache. autodiff, cache. jac_cache, cache. uf, cache. fu, u)
157
+ uf = JacobianWrapper {iip} (cache. f, p)
158
+ sparse_jacobian! (J, cache. autodiff, cache. jac_cache, uf, cache. fu, u)
143
159
end
144
- J_ = J
160
+ return J
145
161
else
146
- J_ = if has_jac (cache. f)
147
- cache. f. jac (u, p)
148
- elseif __can_setindex (typeof (J))
149
- sparse_jacobian! (J, cache. autodiff, cache. jac_cache, cache. uf, u)
150
- J
162
+ if SciMLBase. has_jac (cache. f)
163
+ return cache. f. jac (u, p)
164
+ elseif cache. di_extras != = nothing
165
+ return DI. jacobian (cache. f, cache. di_extras, cache. autodiff, u, Constant (p))
151
166
else
152
- sparse_jacobian (cache. autodiff, cache. jac_cache, cache. uf, u)
167
+ uf = JacobianWrapper {iip} (cache. f, p)
168
+ if __can_setindex (typeof (J))
169
+ sparse_jacobian! (J, cache. autodiff, cache. sdifft_extras, uf, u)
170
+ return J
171
+ else
172
+ return sparse_jacobian (cache. autodiff, cache. sdifft_extras, uf, u)
173
+ end
153
174
end
154
175
end
155
- return J_
156
176
end
157
177
158
178
# Sparsity Detection Choices
183
203
if SciMLBase. has_colorvec (f)
184
204
return PrecomputedJacobianColorvec (; jac_prototype,
185
205
f. colorvec,
186
- partition_by_rows = (ad isa AutoSparse &&
187
- ADTypes. mode (ad) isa ADTypes. ReverseMode))
206
+ partition_by_rows = ADTypes. mode (ad) isa ADTypes. ReverseMode)
188
207
else
189
208
return JacPrototypeSparsityDetection (; jac_prototype)
190
209
end
0 commit comments