@@ -54,7 +54,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
54
54
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
55
55
fu = f. resid_prototype === nothing ? (iip ? _mutable_zero (u) : _mutable (f (u, p))) :
56
56
(iip ? deepcopy (f. resid_prototype) : f. resid_prototype)
57
- if ! has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
57
+ if ! has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ)
58
58
sd = sparsity_detection_alg (f, alg. ad)
59
59
ad = alg. ad
60
60
jac_cache = iip ? sparse_jacobian_cache (ad, sd, uf, fu, _maybe_mutable (u, ad)) :
@@ -92,9 +92,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
92
92
du = _mutable_zero (u)
93
93
94
94
if needsJᵀJ
95
- JᵀJ = __init_JᵀJ (J)
96
- # FIXME : This needs to be handled better for JacVec Operator
97
- Jᵀfu = J ' * _vec (fu )
95
+ # TODO : Pass in `jac_transpose_autodiff`
96
+ JᵀJ, Jᵀfu = __init_JᵀJ (J, _vec (fu), uf, u;
97
+ jac_autodiff = __get_nonsparse_ad (alg . ad) )
98
98
end
99
99
100
100
if linsolve_init
@@ -120,21 +120,68 @@ function __setup_linsolve(A, b, u, p, alg)
120
120
nothing )... , weight)
121
121
return init (linprob, alg. linsolve; alias_A = true , alias_b = true , Pl, Pr)
122
122
end
123
# Linear-solve setup for a `KrylovJᵀJ` wrapper: forward to the `__setup_linsolve` method
# for the operator stored in its `JᵀJ` field, passing every other argument through as-is.
function __setup_linsolve(A::KrylovJᵀJ, b, u, p, alg)
    return __setup_linsolve(A.JᵀJ, b, u, p, alg)
end
123
124
124
125
# Translate a sparse AD backend selector into its non-sparse counterpart. Any backend
# without a dedicated method below is returned unchanged.
function __get_nonsparse_ad(::AutoSparseForwardDiff)
    return AutoForwardDiff()
end
function __get_nonsparse_ad(::AutoSparseFiniteDiff)
    return AutoFiniteDiff()
end
function __get_nonsparse_ad(::AutoSparseZygote)
    return AutoZygote()
end
function __get_nonsparse_ad(ad)
    return ad
end
128
129
129
# Single-argument initialisers for the `JᵀJ` storage.
# Scalar Jacobian: the storage is just a zero of the same numeric type.
function __init_JᵀJ(J::Number)
    return zero(J)
end
# Array Jacobian: allocate the storage by forming the product `J' * J` once.
function __init_JᵀJ(J::AbstractArray)
    Jᵀ = J'
    return Jᵀ * J
end
131
# Static-array Jacobian: return an uninitialised square `MArray` whose side length is the
# number of columns of `J` and whose element type matches `J`.
function __init_JᵀJ(J::StaticArray)
    ncols = size(J, 2)
    return MArray{Tuple{ncols, ncols}, eltype(J)}(undef)
end
130
# Scalar Jacobian: both returned values (the `JᵀJ` and `Jᵀfu` slots) are a zero of the same
# type as `J`. Extra positional and keyword arguments are accepted and ignored.
function __init_JᵀJ(J::Number, args...; kwargs...)
    return (zero(J), zero(J))
end
131
# Array Jacobian: return the pair `(J' * J, J' * fu)`. Extra positional and keyword
# arguments are accepted and ignored.
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
    Jᵀ = J'
    return Jᵀ * J, Jᵀ * fu
end
136
# Static-array Jacobian: the `JᵀJ` slot is an uninitialised square `MArray` (side length =
# number of columns of `J`, element type of `J`); the second value is `J' * fu`.
# Extra positional and keyword arguments are accepted and ignored.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    ncols = size(J, 2)
    storage = MArray{Tuple{ncols, ncols}, eltype(J)}(undef)
    return storage, J' * fu
end
140
# Operator Jacobian (`FunctionOperator`): no matrix is formed. Instead:
#   1. choose an AD backend via `__concrete_jac_transpose_autodiff`,
#   2. build a `VecJac` operator from `uf` and `u` with that backend,
#   3. compose it with `J`, run `SciMLOperators.cache_operator` on the product with `u`,
#   4. wrap the cached product together with the `VecJac` operator in a `KrylovJᵀJ`.
# Returns the wrapper and the result of applying the `VecJac` operator to `fu`.
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...;
        jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...)
    vjp_backend = __concrete_jac_transpose_autodiff(jac_transpose_autodiff,
        jac_autodiff, uf)
    vecjac_op = VecJac(uf, u; autodiff = vjp_backend)
    wrapped = KrylovJᵀJ(SciMLOperators.cache_operator(vecjac_op * J, u), vecjac_op)
    return wrapped, vecjac_op * fu
end
149
+
150
# Wrapper pairing two operators so that methods can dispatch on the matrix-free case.
# Elsewhere in this file it is constructed from a cached `Jᵀ * J` operator and the
# `Jᵀ` operator that product was built from.
@concrete struct KrylovJᵀJ
    JᵀJ  # the composed `Jᵀ * J` operator
    Jᵀ   # the transpose-Jacobian operator on its own
end
154
+
155
# In-placeness of a `KrylovJᵀJ` is defined as that of its stored `Jᵀ` operator.
function SciMLBase.isinplace(JᵀJ::KrylovJᵀJ)
    return isinplace(JᵀJ.Jᵀ)
end
156
+
157
# Pick the AD backend for the transpose-Jacobian operator.
#
# Arguments:
#   - `jac_transpose_autodiff`: explicit backend choice, or `nothing` to auto-select.
#   - `jac_autodiff`: backend used for the Jacobian itself (only consulted when
#     auto-selecting for an out-of-place function).
#   - `uf`: the wrapped function; only `isinplace(uf)` is queried.
#
# Selection order:
#   1. An explicit `jac_transpose_autodiff` wins, after `__get_nonsparse_ad` is applied.
#   2. In-place `uf`            -> `AutoFiniteDiff()`.
#   3. `jac_autodiff` is already an `AutoFiniteDiff` -> reuse it.
#   4. Zygote is loaded in this session -> `AutoZygote()`, otherwise `AutoFiniteDiff()`.
function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
    if jac_transpose_autodiff !== nothing
        return __get_nonsparse_ad(jac_transpose_autodiff)
    end

    # VecJac can be only FiniteDiff for in-place functions
    isinplace(uf) && return AutoFiniteDiff()

    # Short circuit if we see that FiniteDiff was used for J computation
    jac_autodiff isa AutoFiniteDiff && return jac_autodiff

    # Check whether Zygote is loaded. The UUID and name literals must match the registered
    # package exactly: any stray whitespace makes `Base.UUID` reject the string (and the
    # `PkgId` could never match a loaded module).
    zygote_id = Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")
    return haskey(Base.loaded_modules, zygote_id) ? AutoZygote() : AutoFiniteDiff()
end
132
177
133
178
# Wrap `x` in `Symmetric` by default; plain numbers are passed through untouched.
function __maybe_symmetric(x)
    return Symmetric(x)
end
function __maybe_symmetric(x::Number)
    return x
end
135
180
# Types that are returned as-is instead of being wrapped in `Symmetric`.
# LinearSolve with `nothing` doesn't dispatch correctly here
function __maybe_symmetric(x::StaticArray)
    return x
end
function __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix)
    return x
end
function __maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator)
    return x
end
function __maybe_symmetric(x::KrylovJᵀJ)
    return x
end
138
185
139
186
## Special Handling for Scalars
140
187
function jacobian_caches (alg:: AbstractNonlinearSolveAlgorithm , f:: F , u:: Number , p,
@@ -145,3 +192,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
145
192
needsJᵀJ && return uf, nothing , u, nothing , nothing , u, u, u
146
193
return uf, nothing , u, nothing , nothing , u
147
194
end
195
+
196
# Refresh the `JᵀJ` value stored in `cache` under the property named `sym`.
#
# The 4-argument entry point reads the current property value and re-dispatches on it, so
# that specialised storage types can supply their own 5-argument methods.
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    current = getproperty(cache, sym)
    return __update_JᵀJ!(iip, cache, sym, current, J)
end
# Out-of-place: compute a fresh `J' * J` and store it back into the cache property.
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J)
    return setproperty!(cache, sym, J' * J)
end
# In-place: overwrite the existing property value with `J' * J` via `mul!`.
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J)
    return mul!(getproperty(cache, sym), J', J)
end
201
# When the stored value is a `KrylovJᵀJ`, these methods perform no update and simply
# return the existing wrapper `H`, for both the out-of-place and in-place variants.
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
203
+
204
# Refresh the `Jᵀ * fu` value stored in `cache` under the property named `sym1`.
#
# The 6-argument entry point reads the property named `sym2` and re-dispatches on that
# value, so specialised storage types can supply their own 7-argument methods.
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
    selector = getproperty(cache, sym2)
    return __update_Jᵀf!(iip, cache, sym1, sym2, selector, J, fu)
end
# Out-of-place: compute a fresh `J' * fu` and store it under `sym1`.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    product = J' * fu
    return setproperty!(cache, sym1, product)
end
# In-place: overwrite the existing `sym1` value with `J' * fu` via `mul!`.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    target = getproperty(cache, sym1)
    return mul!(target, J', fu)
end
213
# `KrylovJᵀJ` variants: the product is formed with the wrapper's stored `Jᵀ` operator
# rather than with the adjoint of `J`.
# Out-of-place: store `H.Jᵀ * fu` under `sym1`.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J,
        fu)
    product = H.Jᵀ * fu
    return setproperty!(cache, sym1, product)
end
# In-place: overwrite the existing `sym1` value using `mul!` with `H.Jᵀ`.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J,
        fu)
    target = getproperty(cache, sym1)
    return mul!(target, H.Jᵀ, fu)
end
219
+
220
+ # Left-Right Multiplication
221
# Generic left-right product: evaluates the three-argument `dot(g, H, g)`.
function __lr_mul(::Val, H, g)
    return dot(g, H, g)
end
222
## TODO: Use a cache here to avoid allocations
# `KrylovJᵀJ` variants of the left-right product, using the wrapper's stored `JᵀJ` operator.
# Out-of-place: three-argument `dot` directly on the operator.
function __lr_mul(::Val{false}, H::KrylovJᵀJ, g)
    return dot(g, H.JᵀJ, g)
end
# In-place: apply the operator into a freshly allocated buffer, then take `dot` with `g`.
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    Hg = similar(g)
    mul!(Hg, H.JᵀJ, g)
    return dot(g, Hg)
end
0 commit comments