Skip to content

Commit e32d6e5

Browse files
committed
feat: SimpleLimitedMemoryBroyden impl
1 parent dd2c011 commit e32d6e5

File tree

2 files changed

+329
-2
lines changed

2 files changed

+329
-2
lines changed

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using PrecompileTools: @compile_workload, @setup_workload
1111
using Reexport: @reexport
1212
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
1313
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
14-
using StaticArraysCore: StaticArray, SVector
14+
using StaticArraysCore: StaticArray, SArray, SVector, MArray
1515

1616
# AD Dependencies
1717
using ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
@@ -107,7 +107,7 @@ export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
107107

108108
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
109109

110-
export SimpleBroyden, SimpleKlement
110+
export SimpleBroyden, SimpleKlement, SimpleLimitedMemoryBroyden
111111
export SimpleDFSane
112112
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
113113
export SimpleHalley
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,328 @@
1+
"""
    SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27),
        linesearch = Val(false), alpha = nothing)

A limited memory implementation of Broyden. This method applies the L-BFGS scheme to
Broyden's method.

If the threshold is larger than the problem size, then this method will use `SimpleBroyden`.
Scalar problems always use `SimpleBroyden`.

### Keyword Arguments:

  - `threshold`: Number of low-rank update pairs that are stored. Once this many updates
    have been stored, subsequent updates overwrite earlier slots cyclically. Can be
    passed as an `Int` or as a `Val`; an `Int` is wrapped into a `Val` by the constructor.
  - `linesearch`: If `linesearch` is `Val(true)` (or `true`), then we use the
    `LiFukushimaLineSearch` line search else no line search is used. For advanced
    customization of the line search, use `Broyden` from `NonlinearSolve.jl`.
  - `alpha`: Scale the initial jacobian initialization with `alpha`. If it is `nothing`, we
    will compute the scaling using `2 * norm(fu) / max(norm(u), true)`.

!!! warning

    Currently `alpha` is only used for StaticArray problems (more precisely, for `SArray`
    initial guesses solved with `termination_condition = nothing` or
    `AbsNormTerminationMode`). This will be fixed in the future.
"""
@concrete struct SimpleLimitedMemoryBroyden <: AbstractSimpleNonlinearSolveAlgorithm
    linesearch <: Union{Val{false}, Val{true}}
    threshold <: Val
    alpha
end
28+
29+
# Keyword constructor: normalises the user-facing `Bool`/`Int` conveniences into the
# `Val` wrappers that the struct fields require, then forwards to the positional
# constructor.
function SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27),
        linesearch::Union{Bool, Val{true}, Val{false}} = Val(false), alpha = nothing)
    ls_flag = if linesearch isa Bool
        Val(linesearch)
    else
        linesearch
    end
    memory = if threshold isa Int
        Val(threshold)
    else
        threshold
    end
    return SimpleLimitedMemoryBroyden(ls_flag, memory, alpha)
end
35+
36+
# Entry point: picks between the specialised `SArray` solver and the generic solver.
#
# The static solver is only used when `prob.u0` is an `SArray` *and* the termination
# condition is either `nothing` or an `AbsNormTerminationMode`. An `SArray` combined
# with any other termination condition emits a one-time warning and falls back to the
# generic solver.
function SciMLBase.__solve(
        prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
        args...; termination_condition = nothing, kwargs...)
    # Non-`SArray` inputs never take the static path.
    prob.u0 isa SArray || return internal_generic_solve(
        prob, alg, args...; termination_condition, kwargs...)

    static_supported = termination_condition === nothing ||
                       termination_condition isa AbsNormTerminationMode
    static_supported && return internal_static_solve(
        prob, alg, args...; termination_condition, kwargs...)

    @warn "Specifying `termination_condition = $(termination_condition)` for \
           `SimpleLimitedMemoryBroyden` with `SArray` is not non-allocating. Use \
           either `termination_condition = AbsNormTerminationMode()` or \
           `termination_condition = nothing`." maxlog=1
    return internal_generic_solve(prob, alg, args...; termination_condition, kwargs...)
end
52+
53+
# Generic limited-memory Broyden solver.
#
# Stores the low-rank update factors in `U` (one column per update) and `Vᵀ` (one row
# per update). At most `η = min(threshold, maxiters)` updates are kept; slot
# `mod1(i, η)` is overwritten on iteration `i`.
#
# Scalar problems, and problems whose length does not exceed `η`, are delegated to
# `SimpleBroyden` with the same `linesearch` setting.
#
# Returns the solution object built by `SciMLBase.build_solution`; the return code is
# whatever `Utils.check_termination` reports on convergence, or `ReturnCode.MaxIters`
# if the loop runs out of iterations.
@views function internal_generic_solve(
        prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
        args...; abstol = nothing, reltol = nothing, maxiters = 1000,
        alias_u0 = false, termination_condition = nothing, kwargs...)
    x = Utils.maybe_unaliased(prob.u0, alias_u0)
    # No point in storing more update pairs than iterations we are allowed to run.
    η = min(SciMLBase._unwrap_val(alg.threshold), maxiters)

    # For scalar problems / if the threshold is larger than problem size just use Broyden
    if x isa Number || length(x) ≤ η
        return SciMLBase.__solve(prob, SimpleBroyden(; alg.linesearch), args...; abstol,
            reltol, maxiters, termination_condition, kwargs...)
    end

    fx = Utils.get_fx(prob, x)

    U, Vᵀ = init_low_rank_jacobian(x, fx, x isa StaticArray ? alg.threshold : Val(η))

    abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
        prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

    @bb xo = copy(x)
    # Initial search direction is -f(x).
    @bb δx = copy(fx)
    @bb δx .*= -1
    @bb fo = copy(fx)
    @bb δf = copy(fx)

    @bb vᵀ_cache = copy(x)
    Tcache = lbroyden_threshold_cache(x, x isa StaticArray ? alg.threshold : Val(η))
    @bb mat_cache = copy(x)

    if alg.linesearch === Val(true)
        ls_alg = LiFukushimaLineSearch(; nan_maxiters = nothing)
        ls_cache = init(prob, ls_alg, fx, x)
    else
        ls_cache = nothing
    end

    for i in 1:maxiters
        if ls_cache === nothing
            α = true
        else
            ls_sol = solve!(ls_cache, xo, δx)
            α = ls_sol.step_size # Ignores the return code for now
        end

        @bb @. x = xo + α * δx
        fx = Utils.eval_f(prob, fx, x)
        @bb @. δf = fx - fo

        # Termination Checks
        solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
        solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)

        # Only the first `min(η, i - 1)` slots hold updates from previous iterations.
        Uₚ = selectdim(U, 2, 1:min(η, i - 1))
        Vᵀₚ = selectdim(Vᵀ, 1, 1:min(η, i - 1))

        vᵀ = rmatvec!!(vᵀ_cache, Tcache, Uₚ, Vᵀₚ, δx)
        mvec = matvec!!(mat_cache, Tcache, Uₚ, Vᵀₚ, δf)
        d = dot(vᵀ, δf)
        @bb @. δx = (δx - mvec) / d

        # Store the new update pair, overwriting slots cyclically.
        selectdim(U, 2, mod1(i, η)) .= Utils.safe_vec(δx)
        selectdim(Vᵀ, 1, mod1(i, η)) .= Utils.safe_vec(vᵀ)

        # Next search direction uses the factors including the update just stored.
        Uₚ = selectdim(U, 2, 1:min(η, i))
        Vᵀₚ = selectdim(Vᵀ, 1, 1:min(η, i))
        δx = matvec!!(δx, Tcache, Uₚ, Vᵀₚ, fx)
        @bb @. δx *= -1

        @bb copyto!(xo, x)
        @bb copyto!(fo, fx)
    end

    return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end
128+
129+
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
# finicky, so we'll implement it separately from the generic version
# Ignore termination_condition. Don't pass things into internal functions
#
# Convergence is declared when `maximum(abs, fx) ≤ abstol`. The first `threshold`
# iterations are performed by `internal_unrolled_lbroyden_initial_iterations`; the loop
# below performs the remaining `maxiters - threshold` iterations, overwriting the
# stored update pairs cyclically via `mod1(i, threshold)`.
#
# `init_α` is the reciprocal of the initial Jacobian scaling: `inv(alg.alpha)` when
# `alpha` is given, otherwise `max(‖x‖, 1) / (2‖fx‖)` when `‖fx‖ ≥ 1e-5` and `1`
# otherwise.
function internal_static_solve(
        prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
        args...; abstol = nothing, maxiters = 1000, kwargs...)
    x = prob.u0
    fx = Utils.get_fx(prob, x)

    U, Vᵀ = init_low_rank_jacobian(vec(x), vec(fx), alg.threshold)

    abstol = NonlinearSolveBase.get_tolerance(x, abstol, eltype(x))

    xo, δx, fo, δf = x, -fx, fx, fx

    if alg.linesearch === Val(true)
        ls_alg = LiFukushimaLineSearch(; nan_maxiters = nothing)
        ls_cache = init(prob, ls_alg, fx, x)
    else
        ls_cache = nothing
    end

    T = promote_type(eltype(x), eltype(fx))
    if alg.alpha === nothing
        fx_norm = L2_NORM(fx)
        x_norm = L2_NORM(x)
        # Guard against dividing by a (near-)zero residual norm.
        init_α = ifelse(fx_norm ≥ 1e-5, max(x_norm, T(true)) / (2 * fx_norm), T(true))
    else
        init_α = inv(alg.alpha)
    end

    converged, res = internal_unrolled_lbroyden_initial_iterations(
        prob, xo, fo, δx, abstol, U, Vᵀ, alg.threshold, ls_cache, init_α)

    converged && return SciMLBase.build_solution(
        prob, alg, res.x, res.fx; retcode = ReturnCode.Success)

    xo, fo, δx = res.x, res.fx, res.δx

    for i in 1:(maxiters - SciMLBase._unwrap_val(alg.threshold))
        if ls_cache === nothing
            α = true
        else
            ls_sol = solve!(ls_cache, xo, δx)
            α = ls_sol.step_size # Ignores the return code for now
        end

        x = xo + α * δx
        fx = Utils.eval_f(prob, fx, x)
        δf = fx - fo

        maximum(abs, fx) ≤ abstol &&
            return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

        vᵀ = Utils.restructure(x, rmatvec!!(U, Vᵀ, vec(δx), init_α))
        mvec = Utils.restructure(x, matvec!!(U, Vᵀ, vec(δf), init_α))

        d = dot(vᵀ, δf)
        δx = @. (δx - mvec) / d

        # `Base.setindex` returns a new immutable container with slot replaced.
        U = Base.setindex(U, vec(δx), mod1(i, SciMLBase._unwrap_val(alg.threshold)))
        Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), mod1(i, SciMLBase._unwrap_val(alg.threshold)))

        δx = -Utils.restructure(fx, matvec!!(U, Vᵀ, vec(fx), init_α))

        xo, fo = x, fx
    end

    return SciMLBase.build_solution(prob, alg, xo, fo; retcode = ReturnCode.MaxIters)
end
199+
200+
# Generates the first `threshold` limited-memory Broyden iterations fully unrolled, so
# that iteration `i` only touches the first `i - 1` (then `i`) stored update pairs via
# statically sized `first_n_getindex` slices.
#
# Returns `(converged::Bool, (; x, fx, δx))`. `converged` is `true` as soon as
# `maximum(abs, fx) ≤ abstol`.
@generated function internal_unrolled_lbroyden_initial_iterations(
        prob, xo, fo, δx, abstol, U, Vᵀ, ::Val{threshold},
        ls_cache, init_α) where {threshold}
    calls = []
    for i in 1:threshold
        static_idx, static_idx_p1 = Val(i - 1), Val(i)
        push!(calls, quote
            # Same line-search call as the non-unrolled solver loops use on this cache.
            α = ls_cache === nothing ? true : solve!(ls_cache, xo, δx).step_size
            x = xo .+ α .* δx
            fx = prob.f(x, prob.p)
            δf = fx - fo

            maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx)

            Uₚ = first_n_getindex(U, $(static_idx))
            Vᵀₚ = first_n_getindex(Vᵀ, $(static_idx))

            vᵀ = Utils.restructure(x, rmatvec!!(Uₚ, Vᵀₚ, vec(δx), init_α))
            mvec = Utils.restructure(x, matvec!!(Uₚ, Vᵀₚ, vec(δf), init_α))

            d = dot(vᵀ, δf)
            δx = @. (δx - mvec) / d

            U = Base.setindex(U, vec(δx), $(i))
            Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), $(i))

            Uₚ = first_n_getindex(U, $(static_idx_p1))
            Vᵀₚ = first_n_getindex(Vᵀ, $(static_idx_p1))
            δx = -Utils.restructure(fx, matvec!!(Uₚ, Vᵀₚ, vec(fx), init_α))

            # Advance the previous iterate; the next unrolled step starts from `x`.
            xo, fo = x, fx
        end)
    end
    push!(calls, quote
        # Termination Check
        maximum(abs, fx) ≤ abstol && return true, (; x, fx, δx)

        return false, (; x, fx, δx)
    end)
    return Expr(:block, calls...)
end
241+
242+
# Computes xᵀ × (-I + UVᵀ) into `y` (in place when `y` is mutable, via `@bb`).
# `xᵀU` is scratch storage; only its first `size(U, 2)` entries are used.
function rmatvec!!(y, xᵀU, U, Vᵀ, x)
    nupdates = size(U, 2)
    if nupdates != 0
        xflat = vec(x)
        scratch = view(xᵀU, 1:nupdates)
        @bb scratch = transpose(U) × xflat
        @bb y = transpose(Vᵀ) × vec(scratch)
        @bb @. y -= x
    else
        # No stored updates: the operator reduces to -I.
        @bb @. y = -x
    end
    return y
end
256+
257+
# Static variants of xᵀ × (-init_α * I + UVᵀ). `U === nothing` means no updates are
# stored yet, so only the scaled identity term remains.
function rmatvec!!(::Nothing, Vᵀ, x, init_α)
    return -x .* init_α
end
function rmatvec!!(U, Vᵀ, x, init_α)
    xᵀU = fast_mapdot(x, U)
    return fast_mapTdot(xᵀU, Vᵀ) .- x .* init_α
end
259+
260+
# Computes (-I + UVᵀ) × x into `y` (in place when `y` is mutable, via `@bb`).
# `Vᵀx` is scratch storage; only its first `size(U, 2)` entries are used.
function matvec!!(y, Vᵀx, U, Vᵀ, x)
    nupdates = size(U, 2)
    if nupdates != 0
        xflat = vec(x)
        scratch = view(Vᵀx, 1:nupdates)
        @bb scratch = Vᵀ × xflat
        @bb y = U × vec(scratch)
        @bb @. y -= x
    else
        # No stored updates: the operator reduces to -I.
        @bb @. y = -x
    end
    return y
end
274+
275+
# Static variants of (-init_α * I + UVᵀ) × x. `U === nothing` means no updates are
# stored yet, so only the scaled identity term remains.
@inline function matvec!!(::Nothing, Vᵀ, x, init_α)
    return -x .* init_α
end
@inline function matvec!!(U, Vᵀ, x, init_α)
    Vᵀx = fast_mapdot(x, Vᵀ)
    return fast_mapTdot(Vᵀx, U) .- x .* init_α
end
277+
278+
# Dot product of `x` with every inner vector of `Y`; returns one entry per inner vector.
function fast_mapdot(x::SVector{S1}, Y::SVector{S2, <:SVector{S1}}) where {S1, S2}
    return map(yᵢ -> dot(x, yᵢ), Y)
end
281+
# Unrolled linear combination `x[1] .* Y[1] .+ … .+ x[S1] .* Y[S1]` of the inner
# vectors of `Y`, weighted by the entries of `x`.
@generated function fast_mapTdot(
        x::SVector{S1}, Y::SVector{S1, <:SVector{S2}}) where {S1, S2}
    terms = [gensym("m$(k)") for k in 1:S1]
    stmts = Any[:($(terms[k]) = x[$(k)] .* Y[$k]) for k in 1:S1]
    push!(stmts, :(return .+($(terms...))))
    return Expr(:block, stmts...)
end
291+
292+
# Returns the first `N` entries of `x` as an `SVector{N, T}`, or `nothing` when
# `N == 0`. `N` must not exceed the length `L` of `x`.
@generated function first_n_getindex(x::SVector{L, T}, ::Val{N}) where {L, T, N}
    @assert N ≤ L
    getcalls = ntuple(i -> :(x[$i]), N)
    N == 0 && return :(return nothing)
    return :(return SVector{$N, $T}(($(getcalls...))))
end
298+
299+
# Scratch buffer of length `threshold` used by the 5-argument `matvec!!`/`rmatvec!!`.
# Generic arrays get a `similar` buffer, mutable static arrays get a zeroed `MArray`,
# and `SArray`s get no buffer at all.
function lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold}
    return similar(x, threshold)
end
function lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
    BufferT = MArray{Tuple{threshold}, eltype(x)}
    return zeros(BufferT)
end
function lbroyden_threshold_cache(::SArray, ::Val{threshold}) where {threshold}
    return nothing
end
304+
305+
# Allocates uninitialised `MArray` storage for the low-rank factors of a static-array
# problem: `U` is `length(fu) × threshold` and `Vᵀ` is `threshold × length(u)`, both
# with the promoted element type of `u` and `fu`.
function init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
        ::Val{threshold}) where {S1, S2, T1, T2, threshold}
    elT = promote_type(T1, T2)
    n_fu = prod(Size(fu))
    n_u = prod(Size(u))
    U = MArray{Tuple{n_fu, threshold}, elT}(undef)
    Vᵀ = MArray{Tuple{threshold, n_u}, elT}(undef)
    return U, Vᵀ
end
313+
# `SVector` specialisation: the factors are stored as an `SVector` of `threshold`
# zero-initialised inner `SVector`s (length `Lfu` for `U`, length `Lu` for `Vᵀ`), with
# the promoted element type of `u` and `fu`.
@generated function init_low_rank_jacobian(u::SVector{Lu, T1}, fu::SVector{Lfu, T2},
        ::Val{threshold}) where {Lu, Lfu, T1, T2, threshold}
    T = promote_type(T1, T2)
    zero_row = :(zeros(SVector{$Lu, $T}))
    zero_col = :(zeros(SVector{$Lfu, $T}))
    rows = [copy(zero_row) for _ in 1:threshold]
    cols = [copy(zero_col) for _ in 1:threshold]
    return quote
        Vᵀ = SVector($(rows...))
        U = SVector($(cols...))
        return U, Vᵀ
    end
end
324+
# Generic fallback: allocates the factors with `similar(u, …)`, so they share `u`'s
# array and element type. `U` is `length(fu) × threshold`, `Vᵀ` is
# `threshold × length(u)`; contents are uninitialised.
function init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
    n_u, n_fu = length(u), length(fu)
    U = similar(u, n_fu, threshold)
    Vᵀ = similar(u, threshold, n_u)
    return U, Vᵀ
end

0 commit comments

Comments
 (0)