Skip to content

Commit d44a4b2

Browse files
committed
refactor: modularize the code
1 parent 4312cf0 commit d44a4b2

File tree

1 file changed

+53
-37
lines changed

1 file changed

+53
-37
lines changed

src/manifold.jl

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -211,21 +211,8 @@ end
211211
function SciMLBase.reinit!(
212212
cache::SingleFactorizeManifoldProjectionCache{iip}, u; p = cache.p) where {iip}
213213
if !cache.first_call || (cache.!== u || cache.p !== p)
214-
if cache.manifold_jacobian !== nothing
215-
if iip
216-
cache.manifold_jacobian(cache.J, u, p)
217-
else
218-
cache.J = cache.manifold_jacobian(u, p)
219-
end
220-
else
221-
if iip
222-
DI.jacobian!(cache.manifold, cache.gu_cache, cache.J,
223-
cache.di_extras, cache.autodiff, u, Constant(p))
224-
else
225-
DI.jacobian!(cache.manifold, cache.J, cache.di_extras,
226-
cache.autodiff, u, Constant(p))
227-
end
228-
end
214+
compute_manifold_jacobian!(cache.J, cache.manifold_jacobian, cache.autodiff,
215+
Val(iip), cache.manifold, cache.gu_cache, u, p, cache.di_extras)
229216
mul!(cache.JJᵀ, cache.J, cache.J')
230217
cache.JJᵀfact = safe_factorize!(cache.JJᵀ)
231218
end
@@ -236,7 +223,7 @@ end
236223

237224
default_abstol(::Type{T}) where {T} = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
238225

239-
function init_manifold_projection(::Val{iip}, manifold, autodiff, manifold_jacobian, ũ,
226+
function init_manifold_projection(IIP::Val{iip}, manifold, autodiff, manifold_jacobian, ũ,
240227
p; abstol = default_abstol(eltype(ũ)), maxiters = 1000,
241228
resid_prototype = nothing) where {iip}
242229
if iip
@@ -254,26 +241,8 @@ function init_manifold_projection(::Val{iip}, manifold, autodiff, manifold_jacob
254241
λ = manifold(ũ, p)
255242
end
256243

257-
if manifold_jacobian !== nothing
258-
if iip
259-
J = similar(ũ, promote_type(eltype(gu), eltype(ũ)), (length(gu), length(ũ)))
260-
manifold_jacobian(J, ũ, p)
261-
else
262-
J = manifold_jacobian(ũ, p)
263-
end
264-
di_extras = nothing
265-
elseif autodiff !== nothing
266-
if iip
267-
di_extras = DI.prepare_jacobian(manifold, gu, autodiff, ũ, Constant(p))
268-
J = DI.jacobian(manifold, gu, di_extras, autodiff, ũ, Constant(p))
269-
else
270-
di_extras = DI.prepare_jacobian(manifold, autodiff, ũ, Constant(p))
271-
J = DI.jacobian(manifold, di_extras, autodiff, ũ, Constant(p))
272-
end
273-
else
274-
error("`autodiff` is set to `nothing` and analytic manifold jacobian is not \
275-
provided.")
276-
end
244+
J, di_extras = setup_manifold_jacobian(manifold_jacobian, autodiff, IIP, manifold,
245+
gu, ũ, p)
277246
JJᵀ = J * J'
278247
JJᵀfact = safe_factorize!(JJᵀ)
279248

@@ -317,9 +286,56 @@ function SciMLBase.solve!(cache::SingleFactorizeManifoldProjectionCache{iip}) wh
317286
ifelse(internal_solve_failed, ReturnCode.ConvergenceFailure, ReturnCode.Success))
318287
end
319288

289+
function setup_manifold_jacobian(
290+
manifold_jacobian::M, autodiff, ::Val{iip}, manifold, gu, ũ, p) where {M, iip}
291+
if iip
292+
J = similar(ũ, promote_type(eltype(gu), eltype(ũ)), (length(gu), length(ũ)))
293+
manifold_jacobian(J, ũ, p)
294+
else
295+
J = manifold_jacobian(ũ, p)
296+
end
297+
return J, nothing
298+
end
299+
300+
function setup_manifold_jacobian(
301+
::Nothing, autodiff, ::Val{iip}, manifold, gu, ũ, p) where {iip}
302+
if iip
303+
di_extras = DI.prepare_jacobian(manifold, gu, autodiff, ũ, Constant(p))
304+
J = DI.jacobian(manifold, gu, di_extras, autodiff, ũ, Constant(p))
305+
else
306+
di_extras = DI.prepare_jacobian(manifold, autodiff, ũ, Constant(p))
307+
J = DI.jacobian(manifold, di_extras, autodiff, ũ, Constant(p))
308+
end
309+
return J, di_extras
310+
end
311+
312+
function compute_manifold_jacobian!(J, manifold_jacobian, autodiff, ::Val{iip},
313+
manifold, gu, ũ, p, di_extras) where {iip}
314+
if iip
315+
manifold_jacobian(J, ũ, p)
316+
else
317+
J = manifold_jacobian(ũ, p)
318+
end
319+
return J
320+
end
321+
322+
function compute_manifold_jacobian!(J, ::Nothing, autodiff, ::Val{iip}, manifold, gu,
323+
ũ, p, di_extras) where {iip}
324+
if iip
325+
DI.jacobian!(manifold, gu, J, di_extras, autodiff, ũ, Constant(p))
326+
else
327+
DI.jacobian!(manifold, J, di_extras, autodiff, ũ, Constant(p))
328+
end
329+
return J
330+
end
331+
332+
function setup_manifold_jacobian(::Nothing, ::Nothing, args...)
333+
error("`autodiff` is set to `nothing` and analytic manifold jacobian is not provided.")
334+
end
335+
320336
function safe_factorize!(A::AbstractMatrix)
321337
if issquare(A)
322-
fact = LinearAlgebra.lu(A; check = false)
338+
fact = LinearAlgebra.cholesky(A; check = false)
323339
fact_sucessful(fact) && return fact
324340
elseif size(A, 1) > size(A, 2)
325341
fact = LinearAlgebra.qr(A)

0 commit comments

Comments
 (0)