211
211
function SciMLBase. reinit! (
212
212
cache:: SingleFactorizeManifoldProjectionCache{iip} , u; p = cache. p) where {iip}
213
213
if ! cache. first_call || (cache. ũ != = u || cache. p != = p)
214
- if cache. manifold_jacobian != = nothing
215
- if iip
216
- cache. manifold_jacobian (cache. J, u, p)
217
- else
218
- cache. J = cache. manifold_jacobian (u, p)
219
- end
220
- else
221
- if iip
222
- DI. jacobian! (cache. manifold, cache. gu_cache, cache. J,
223
- cache. di_extras, cache. autodiff, u, Constant (p))
224
- else
225
- DI. jacobian! (cache. manifold, cache. J, cache. di_extras,
226
- cache. autodiff, u, Constant (p))
227
- end
228
- end
214
+ compute_manifold_jacobian! (cache. J, cache. manifold_jacobian, cache. autodiff,
215
+ Val (iip), cache. manifold, cache. gu_cache, u, p, cache. di_extras)
229
216
mul! (cache. JJᵀ, cache. J, cache. J' )
230
217
cache. JJᵀfact = safe_factorize! (cache. JJᵀ)
231
218
end
236
223
237
224
default_abstol (:: Type{T} ) where {T} = real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 )
238
225
239
- function init_manifold_projection (:: Val{iip} , manifold, autodiff, manifold_jacobian, ũ,
226
+ function init_manifold_projection (IIP :: Val{iip} , manifold, autodiff, manifold_jacobian, ũ,
240
227
p; abstol = default_abstol (eltype (ũ)), maxiters = 1000 ,
241
228
resid_prototype = nothing ) where {iip}
242
229
if iip
@@ -254,26 +241,8 @@ function init_manifold_projection(::Val{iip}, manifold, autodiff, manifold_jacob
254
241
λ = manifold (ũ, p)
255
242
end
256
243
257
- if manifold_jacobian != = nothing
258
- if iip
259
- J = similar (ũ, promote_type (eltype (gu), eltype (ũ)), (length (gu), length (ũ)))
260
- manifold_jacobian (J, ũ, p)
261
- else
262
- J = manifold_jacobian (ũ, p)
263
- end
264
- di_extras = nothing
265
- elseif autodiff != = nothing
266
- if iip
267
- di_extras = DI. prepare_jacobian (manifold, gu, autodiff, ũ, Constant (p))
268
- J = DI. jacobian (manifold, gu, di_extras, autodiff, ũ, Constant (p))
269
- else
270
- di_extras = DI. prepare_jacobian (manifold, autodiff, ũ, Constant (p))
271
- J = DI. jacobian (manifold, di_extras, autodiff, ũ, Constant (p))
272
- end
273
- else
274
- error (" `autodiff` is set to `nothing` and analytic manifold jacobian is not \
275
- provided." )
276
- end
244
+ J, di_extras = setup_manifold_jacobian (manifold_jacobian, autodiff, IIP, manifold,
245
+ gu, ũ, p)
277
246
JJᵀ = J * J'
278
247
JJᵀfact = safe_factorize! (JJᵀ)
279
248
@@ -317,9 +286,56 @@ function SciMLBase.solve!(cache::SingleFactorizeManifoldProjectionCache{iip}) wh
317
286
ifelse (internal_solve_failed, ReturnCode. ConvergenceFailure, ReturnCode. Success))
318
287
end
319
288
289
+ function setup_manifold_jacobian (
290
+ manifold_jacobian:: M , autodiff, :: Val{iip} , manifold, gu, ũ, p) where {M, iip}
291
+ if iip
292
+ J = similar (ũ, promote_type (eltype (gu), eltype (ũ)), (length (gu), length (ũ)))
293
+ manifold_jacobian (J, ũ, p)
294
+ else
295
+ J = manifold_jacobian (ũ, p)
296
+ end
297
+ return J, nothing
298
+ end
299
+
300
+ function setup_manifold_jacobian (
301
+ :: Nothing , autodiff, :: Val{iip} , manifold, gu, ũ, p) where {iip}
302
+ if iip
303
+ di_extras = DI. prepare_jacobian (manifold, gu, autodiff, ũ, Constant (p))
304
+ J = DI. jacobian (manifold, gu, di_extras, autodiff, ũ, Constant (p))
305
+ else
306
+ di_extras = DI. prepare_jacobian (manifold, autodiff, ũ, Constant (p))
307
+ J = DI. jacobian (manifold, di_extras, autodiff, ũ, Constant (p))
308
+ end
309
+ return J, di_extras
310
+ end
311
+
312
+ function compute_manifold_jacobian! (J, manifold_jacobian, autodiff, :: Val{iip} ,
313
+ manifold, gu, ũ, p, di_extras) where {iip}
314
+ if iip
315
+ manifold_jacobian (J, ũ, p)
316
+ else
317
+ J = manifold_jacobian (ũ, p)
318
+ end
319
+ return J
320
+ end
321
+
322
+ function compute_manifold_jacobian! (J, :: Nothing , autodiff, :: Val{iip} , manifold, gu,
323
+ ũ, p, di_extras) where {iip}
324
+ if iip
325
+ DI. jacobian! (manifold, gu, J, di_extras, autodiff, ũ, Constant (p))
326
+ else
327
+ DI. jacobian! (manifold, J, di_extras, autodiff, ũ, Constant (p))
328
+ end
329
+ return J
330
+ end
331
+
332
+ function setup_manifold_jacobian (:: Nothing , :: Nothing , args... )
333
+ error (" `autodiff` is set to `nothing` and analytic manifold jacobian is not provided." )
334
+ end
335
+
320
336
function safe_factorize! (A:: AbstractMatrix )
321
337
if issquare (A)
322
- fact = LinearAlgebra. lu (A; check = false )
338
+ fact = LinearAlgebra. cholesky (A; check = false )
323
339
fact_sucessful (fact) && return fact
324
340
elseif size (A, 1 ) > size (A, 2 )
325
341
fact = LinearAlgebra. qr (A)
0 commit comments