Skip to content

Commit 7b52768

Browse files
committed
feat: lagrangian multiplier based projection algorithm
1 parent d44a4b2 commit 7b52768

File tree

2 files changed

+126
-26
lines changed

2 files changed

+126
-26
lines changed

src/manifold.jl

Lines changed: 105 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ properties.
4444
`nlsolve` is not `missing`.
4545
- `autodiff`: The autodifferentiation algorithm to use to compute the Jacobian if
4646
`manifold_jacobian` is not specified. This must be specified if `manifold_jacobian` is
47-
not specified and `nlsolve` is `missing`. If `nlsolve` is not `missing`, then
48-
`autodiff` is ignored.
47+
not specified.
4948
- `manifold_jacobian`: The Jacobian of the manifold (wrt the state). This has the same
5049
signature as `manifold` and the first argument is the Jacobian if inplace.
5150
@@ -118,13 +117,7 @@ function (proj::ManifoldProjection)(integrator)
118117
proj.manifold_jacobian !== nothing && (proj.manifold_jacobian.t = integrator.t)
119118

120119
SciMLBase.reinit!(proj.nlcache, integrator.u; integrator.p)
121-
122-
if proj.nlsolve === missing
123-
_, u, retcode = SciMLBase.solve!(proj.nlcache)
124-
else
125-
sol = SciMLBase.solve!(proj.nlcache)
126-
(; u, retcode) = sol
127-
end
120+
_, u, retcode = SciMLBase.solve!(proj.nlcache)
128121

129122
if !SciMLBase.successful_retcode(retcode)
130123
SciMLBase.terminate!(integrator, retcode)
@@ -146,17 +139,17 @@ function initialize_manifold_projection(affect!::ManifoldProjection, u, t, integ
146139
(affect!.manifold_jacobian.autonomous = autonomous)
147140
end
148141

142+
affect!.manifold.t = t
143+
affect!.manifold_jacobian !== nothing && (affect!.manifold_jacobian.t = t)
144+
149145
if affect!.nlsolve === missing
150-
affect!.manifold.t = t
151-
affect!.manifold_jacobian !== nothing && (affect!.manifold_jacobian.t = t)
152146
cache = init_manifold_projection(
153147
Val(SciMLBase.isinplace(integrator.f)), affect!.manifold, affect!.autodiff,
154148
affect!.manifold_jacobian, u, integrator.p; affect!.kwargs...)
155149
else
156-
# nlfunc = NonlinearFunction{iip}(affect!.g; affect!.resid_prototype)
157-
# nlprob = NonlinearProblem(nlfunc, u, integrator.p)
158-
# affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.kwargs...)
159-
error("Not Implemented")
150+
cache = init_manifold_projection_nonlinear_problem(
151+
Val(SciMLBase.isinplace(integrator.f)), affect!.manifold, affect!.autodiff,
152+
affect!.manifold_jacobian, u, integrator.p, affect!.nlsolve; affect!.kwargs...)
160153
end
161154
affect!.nlcache = cache
162155
u_modified!(integrator, false)
@@ -187,6 +180,97 @@ function (f::UntypedNonAutonomousFunction)(res, u, p)
187180
end
188181
(f::UntypedNonAutonomousFunction)(u, p) = f.autonomous ? f.f(u, p) : f.f(u, p, f.t)
189182

183+
# This is solving the langrange multiplier formulation. This is more accurate but at the
184+
# same time significantly more expensive.
185+
@concrete mutable struct NonlinearSolveManifoldProjectionCache{iip}
186+
manifold
187+
p
188+
λ
189+
z
190+
191+
gu_cache
192+
nlcache
193+
194+
first_call::Bool
195+
J
196+
manifold_jacobian
197+
autodiff
198+
di_extras
199+
end
200+
201+
function SciMLBase.reinit!(
202+
cache::NonlinearSolveManifoldProjectionCache{iip}, u; p = cache.p) where {iip}
203+
if !cache.first_call || (cache.!== u || cache.p !== p)
204+
compute_manifold_jacobian!(cache.J, cache.manifold_jacobian, cache.autodiff,
205+
Val(iip), cache.manifold, cache.gu_cache, u, p, cache.di_extras)
206+
end
207+
cache.first_call = false
208+
cache.= u
209+
cache.p = p
210+
211+
cache.z[1:length(cache.λ)] .= false
212+
cache.z[(length(cache.λ) + 1):end] .= vec(u)
213+
SciMLBase.reinit!(cache.nlcache, cache.z; p = (u, cache.J, p))
214+
end
215+
216+
function init_manifold_projection_nonlinear_problem(
217+
IIP::Val{iip}, manifold, autodiff, manifold_jacobian, ũ, p, alg;
218+
resid_prototype = nothing, kwargs...) where {iip}
219+
if iip
220+
if resid_prototype !== nothing
221+
gu = similar(resid_prototype)
222+
λ = similar(resid_prototype)
223+
else
224+
@warn "`resid_prototype` not provided for in-place problem. Assuming size of \
225+
output is the same as input. This might be incorrect." maxlog=1
226+
gu = similar(ũ)
227+
λ = similar(ũ)
228+
end
229+
else
230+
gu = nothing
231+
λ = manifold(ũ, p)
232+
end
233+
234+
J, di_extras = setup_manifold_jacobian(manifold_jacobian, autodiff, IIP, manifold,
235+
gu, ũ, p)
236+
z = vcat(vec(λ), vec(ũ))
237+
238+
nlfunc = if iip
239+
let λlen = length(λ), λsz = size(λ), zsz = size(ũ)
240+
@views (resid, u, ps) -> begin
241+
ũ2, J2, p2 = ps
242+
λ2, z2 = u[1:λlen], u[(λlen + 1):end]
243+
manifold(reshape(resid[1:λlen], λsz), reshape(z2, zsz), p2)
244+
resid[(λlen + 1):end] .= z2 .- vec(ũ2) .+ vec(vec(J2' * λ2))
245+
end
246+
end
247+
else
248+
let λlen = length(λ), zsz = size(ũ)
249+
@views (u, ps) -> begin
250+
ũ2, J2, p2 = ps
251+
λ2, z2 = u[1:λlen], u[(λlen + 1):end]
252+
gz = vec(manifold(reshape(z2, zsz), p2))
253+
resid = z2 .- vec(ũ2) .+ vec(J2' * λ2)
254+
return vcat(gz, resid)
255+
end
256+
end
257+
end
258+
259+
nlprob = NonlinearProblem(NonlinearFunction{iip}(nlfunc), z, (ũ, J, p))
260+
nlcache = SciMLBase.init(nlprob, alg; kwargs...)
261+
262+
return NonlinearSolveManifoldProjectionCache{iip}(
263+
manifold, p, λ, z, ũ, gu, nlcache, true, J, manifold_jacobian, autodiff, di_extras)
264+
end
265+
266+
@views function SciMLBase.solve!(cache::NonlinearSolveManifoldProjectionCache{iip}) where {iip}
267+
sol = SciMLBase.solve!(cache.nlcache)
268+
(; u, retcode) = sol
269+
λ = reshape(u[1:length(cache.λ)], size(cache.λ))
270+
= reshape(u[(length(cache.λ) + 1):end], size(cache.ũ))
271+
return λ, ũ, retcode
272+
end
273+
190274
# This is the algorithm described in Hairer III.
191275
@concrete mutable struct SingleFactorizeManifoldProjectionCache{iip}
192276
manifold
@@ -225,7 +309,7 @@ default_abstol(::Type{T}) where {T} = real(oneunit(T)) * (eps(real(one(T))))^(4
225309

226310
function init_manifold_projection(IIP::Val{iip}, manifold, autodiff, manifold_jacobian, ũ,
227311
p; abstol = default_abstol(eltype(ũ)), maxiters = 1000,
228-
resid_prototype = nothing) where {iip}
312+
resid_prototype = nothing, kwargs...) where {iip}
229313
if iip
230314
if resid_prototype !== nothing
231315
gu = similar(resid_prototype)
@@ -309,6 +393,11 @@ function setup_manifold_jacobian(
309393
return J, di_extras
310394
end
311395

396+
function setup_manifold_jacobian(
397+
::Nothing, ::Nothing, ::Val{iip}, manifold, gu, ũ, p) where {iip}
398+
error("`autodiff` is set to `nothing` and analytic manifold jacobian is not provided.")
399+
end
400+
312401
function compute_manifold_jacobian!(J, manifold_jacobian, autodiff, ::Val{iip},
313402
manifold, gu, ũ, p, di_extras) where {iip}
314403
if iip
@@ -329,10 +418,6 @@ function compute_manifold_jacobian!(J, ::Nothing, autodiff, ::Val{iip}, manifold
329418
return J
330419
end
331420

332-
function setup_manifold_jacobian(::Nothing, ::Nothing, args...)
333-
error("`autodiff` is set to `nothing` and analytic manifold jacobian is not provided.")
334-
end
335-
336421
function safe_factorize!(A::AbstractMatrix)
337422
if issquare(A)
338423
fact = LinearAlgebra.cholesky(A; check = false)

test/manifold_tests.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,27 @@ solve(prob, Vern7(), callback = cb_t)
3232

3333
# autodiff=false
3434
cb_false = ManifoldProjection(
35-
g; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2))
35+
g; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2),
36+
autodiff = AutoFiniteDiff())
3637
solve(prob, Vern7(), callback = cb_false)
3738
sol = solve(prob, Vern7(), callback = cb_false)
3839
@test sol.u[end][1]^2 + sol.u[end][2]^2 2
3940

4041
cb_t_false = ManifoldProjection(g_t,
41-
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2))
42+
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2),
43+
autodiff = AutoFiniteDiff())
4244
solve(prob, Vern7(), callback = cb_t_false)
4345
sol_t = solve(prob, Vern7(), callback = cb_t_false)
4446
@test sol_t.u == sol.u && sol_t.t == sol.t
4547

4648
# test array partitions
49+
function f_ap!(du, u, p, t)
50+
du[1:2] .= u[3:4]
51+
du[3:4] .= u[1:2]
52+
end
53+
4754
u₀ = ArrayPartition(ones(2), ones(2))
48-
prob = ODEProblem(f, u₀, (0.0, 100.0))
55+
prob = ODEProblem(f_ap!, u₀, (0.0, 100.0))
4956

5057
sol = solve(prob, Vern7(), callback = cb)
5158
@test sol.u[end][1]^2 + sol.u[end][2]^2 2
@@ -71,6 +78,12 @@ sol = solve(prob, Vern7(), callback = cb_unsat)
7178
@test !SciMLBase.successful_retcode(sol)
7279
@test last(sol.t) != 100.0
7380

81+
cb_unsat = ManifoldProjection(
82+
g_unsat; resid_prototype = zeros(2), autodiff = AutoForwardDiff(), nlsolve = NewtonRaphson())
83+
sol = solve(prob, Vern7(), callback = cb_unsat)
84+
@test !SciMLBase.successful_retcode(sol)
85+
@test last(sol.t) != 100.0
86+
7487
# Tests for OOP Manifold Projection
7588
function g_oop(u, p)
7689
return [u[2]^2 + u[1]^2 - 2
@@ -98,20 +111,22 @@ solve(prob, Vern7(), callback = cb_t)
98111

99112
# autodiff=false
100113
cb_false = ManifoldProjection(
101-
g_oop; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), isinplace = Val(false))
114+
g_oop; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), autodiff = AutoFiniteDiff())
102115
solve(prob, Vern7(), callback = cb_false)
103116
sol = solve(prob, Vern7(), callback = cb_false)
104117
@test sol.u[end][1]^2 + sol.u[end][2]^2 2
105118

106119
cb_t_false = ManifoldProjection(g_oop_t,
107-
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), isinplace = Val(false))
120+
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), autodiff = AutoFiniteDiff())
108121
solve(prob, Vern7(), callback = cb_t_false)
109122
sol_t = solve(prob, Vern7(), callback = cb_t_false)
110123
@test sol_t.u == sol.u && sol_t.t == sol.t
111124

112125
# test array partitions
126+
f_ap(u, p, t) = ArrayPartition(u[3:4], u[1:2])
127+
113128
u₀ = ArrayPartition(ones(2), ones(2))
114-
prob = ODEProblem(f, u₀, (0.0, 100.0))
129+
prob = ODEProblem(f_ap, u₀, (0.0, 100.0))
115130

116131
sol = solve(prob, Vern7(), callback = cb)
117132
@test sol.u[end][1]^2 + sol.u[end][2]^2 2

0 commit comments

Comments
 (0)