Skip to content

Commit 5b14b20

Browse files
Merge pull request #297 from avik-pal/ap/inplace
Reduce unnecessary allocations and reuse code
2 parents a39130b + 3b52a5a commit 5b14b20

20 files changed

+1138
-1712
lines changed

Project.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.9.0"
4+
version = "2.10.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -16,6 +16,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1616
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
19+
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1920
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2021
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2122
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -25,7 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2526
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2627
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2728
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
28-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
29+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2930
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3031

3132
[weakdeps]
@@ -42,8 +43,8 @@ NonlinearSolveZygoteExt = "Zygote"
4243

4344
[compat]
4445
ADTypes = "0.2"
45-
ArrayInterface = "6.0.24, 7"
4646
Aqua = "0.8"
47+
ArrayInterface = "6.0.24, 7"
4748
BandedMatrices = "1"
4849
BenchmarkTools = "1"
4950
ConcreteStructs = "0.2"
@@ -59,6 +60,7 @@ LeastSquaresOptim = "0.8"
5960
LineSearches = "7"
6061
LinearAlgebra = "<0.0.1, 1"
6162
LinearSolve = "2.12"
63+
MaybeInplace = "0.1"
6264
NaNMath = "1"
6365
NonlinearProblemLibrary = "0.1"
6466
Pkg = "1"
@@ -72,9 +74,8 @@ SciMLBase = "2.9"
7274
SciMLOperators = "0.3"
7375
SimpleNonlinearSolve = "0.1.23"
7476
SparseArrays = "<0.0.1, 1"
75-
SparseDiffTools = "2.12"
77+
SparseDiffTools = "2.14"
7678
StaticArrays = "1"
77-
StaticArraysCore = "1.4"
7879
Symbolics = "5"
7980
Test = "1"
8081
UnPack = "1.0"

src/NonlinearSolve.jl

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,26 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase,
12-
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
13-
SparseDiffTools
14-
using FastBroadcast: @..
15-
import ArrayInterface: restructure
11+
using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf,
12+
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays
1613

1714
import ADTypes: AbstractFiniteDifferencesMode
18-
import ArrayInterface: undefmatrix,
15+
import ArrayInterface: undefmatrix, restructure, can_setindex,
1916
matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
2017
import ConcreteStructs: @concrete
2118
import EnumX: @enumx
19+
import FastBroadcast: @..
20+
import FiniteDiff
2221
import ForwardDiff
2322
import ForwardDiff: Dual
2423
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
24+
import MaybeInplace: setindex_trait, @bb, CanSetindex, CannotSetindex
2525
import RecursiveArrayTools: ArrayPartition,
2626
AbstractVectorOfArray, recursivecopy!, recursivefill!
2727
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
2828
import SciMLOperators: FunctionOperator
29-
import StaticArraysCore: StaticArray, SVector, SArray, MArray
29+
import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
3030
import UnPack: @unpack
31-
32-
using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
3331
end
3432

3533
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
@@ -52,16 +50,65 @@ abstract type AbstractNonlinearSolveCache{iip} end
5250

5351
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
5452

53+
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache);
54+
p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
55+
maxiters = cache.maxiters, alias_u0 = false, termination_condition = missing,
56+
kwargs...) where {iip}
57+
cache.p = p
58+
if iip
59+
recursivecopy!(get_u(cache), u0)
60+
cache.f(get_fu(cache), get_u(cache), p)
61+
else
62+
cache.u = __maybe_unaliased(u0, alias_u0)
63+
set_fu!(cache, cache.f(cache.u, p))
64+
end
65+
66+
reset!(cache.trace)
67+
68+
# Some algorithms store multiple termination caches
69+
if hasfield(typeof(cache), :tc_cache)
70+
# TODO: We need an efficient way to reset this upstream
71+
tc = termination_condition === missing ? get_termination_mode(cache.tc_cache) :
72+
termination_condition
73+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache),
74+
get_u(cache), tc)
75+
cache.tc_cache = tc_cache
76+
end
77+
78+
if hasfield(typeof(cache), :ls_cache)
79+
# TODO: A more efficient way to do this
80+
cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f,
81+
get_u(cache), p, get_fu(cache), Val(iip))
82+
end
83+
84+
hasfield(typeof(cache), :uf) && (cache.uf.p = p)
85+
86+
cache.abstol = abstol
87+
cache.reltol = reltol
88+
cache.maxiters = maxiters
89+
cache.stats.nf = 1
90+
cache.stats.nsteps = 1
91+
cache.force_stop = false
92+
cache.retcode = ReturnCode.Default
93+
94+
__reinit_internal!(cache; u0, p, abstol, reltol, maxiters, alias_u0,
95+
termination_condition, kwargs...)
96+
97+
return cache
98+
end
99+
100+
__reinit_internal!(::AbstractNonlinearSolveCache; kwargs...) = nothing
101+
55102
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
56103
str = "$(nameof(typeof(alg)))("
57104
modifiers = String[]
58-
if _getproperty(alg, Val(:ad)) !== nothing
105+
if __getproperty(alg, Val(:ad)) !== nothing
59106
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
60107
end
61-
if _getproperty(alg, Val(:linsolve)) !== nothing
108+
if __getproperty(alg, Val(:linsolve)) !== nothing
62109
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
63110
end
64-
if _getproperty(alg, Val(:linesearch)) !== nothing
111+
if __getproperty(alg, Val(:linesearch)) !== nothing
65112
ls = alg.linesearch
66113
if ls isa LineSearch
67114
ls.method !== nothing &&
@@ -70,7 +117,7 @@ function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
70117
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
71118
end
72119
end
73-
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
120+
if __getproperty(alg, Val(:radius_update_scheme)) !== nothing
74121
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
75122
end
76123
str = str * join(modifiers, ", ")
@@ -87,8 +134,9 @@ end
87134
function not_terminated(cache::AbstractNonlinearSolveCache)
88135
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
89136
end
90-
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
91-
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)
137+
138+
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
139+
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
92140
get_u(cache::AbstractNonlinearSolveCache) = cache.u
93141
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)
94142

@@ -107,7 +155,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
107155
end
108156
end
109157

110-
trace = _getproperty(cache, Val{:trace}())
158+
trace = __getproperty(cache, Val{:trace}())
111159
if trace !== nothing
112160
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
113161
nothing, nothing; last = Val(true))

src/broyden.jl

Lines changed: 34 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,14 @@ end
3131
f
3232
alg
3333
u
34-
u_prev
34+
u_cache
3535
du
3636
fu
37-
fu2
37+
fu_cache
3838
dfu
3939
p
4040
J⁻¹
41-
J⁻¹₂
42-
J⁻¹df
41+
J⁻¹dfu
4342
force_stop::Bool
4443
resets::Int
4544
max_resets::Int
@@ -57,144 +56,77 @@ end
5756
trace
5857
end
5958

60-
get_fu(cache::GeneralBroydenCache) = cache.fu
61-
set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu)
62-
6359
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
6460
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
6561
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
6662
kwargs...) where {uType, iip, F}
6763
@unpack f, u0, p = prob
68-
u = alias_u0 ? u0 : deepcopy(u0)
64+
u = __maybe_unaliased(u0, alias_u0)
6965
fu = evaluate_f(prob, u)
70-
du = _mutable_zero(u)
66+
@bb du = copy(u)
7167
J⁻¹ = __init_identity_jacobian(u, fu)
7268
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
7369
alg.reset_tolerance
7470
reset_check = x -> abs(x) reset_tolerance
7571

72+
@bb u_cache = copy(u)
73+
@bb fu_cache = copy(fu)
74+
@bb dfu = similar(fu)
75+
@bb J⁻¹dfu = similar(u)
76+
7677
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
7778
termination_condition)
7879
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
7980
kwargs...)
8081

81-
return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
82-
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
83-
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
84-
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
82+
return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
83+
J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
84+
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
8585
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
8686
end
8787

88-
function perform_step!(cache::GeneralBroydenCache{true})
89-
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache
90-
T = eltype(u)
91-
92-
mul!(_vec(du), J⁻¹, _vec(fu))
93-
α = perform_linesearch!(cache.ls_cache, u, du)
94-
_axpy!(-α, du, u)
95-
f(fu2, u, p)
96-
97-
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
98-
get_fu(cache), J⁻¹, du, α)
99-
100-
check_and_update!(cache, fu2, u, u_prev)
101-
cache.stats.nf += 1
102-
103-
cache.force_stop && return nothing
104-
105-
# Update the inverse jacobian
106-
dfu .= fu2 .- fu
107-
108-
if all(cache.reset_check, du) || all(cache.reset_check, dfu)
109-
if cache.resets cache.max_resets
110-
cache.retcode = ReturnCode.ConvergenceFailure
111-
cache.force_stop = true
112-
return nothing
113-
end
114-
fill!(J⁻¹, 0)
115-
J⁻¹[diagind(J⁻¹)] .= T(1)
116-
cache.resets += 1
117-
else
118-
du .*= -1
119-
mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
120-
mul!(J⁻¹₂, _vec(du)', J⁻¹)
121-
denom = dot(du, J⁻¹df)
122-
du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
123-
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
124-
end
125-
fu .= fu2
126-
@. u_prev = u
127-
128-
return nothing
129-
end
130-
131-
function perform_step!(cache::GeneralBroydenCache{false})
132-
@unpack f, p = cache
133-
88+
function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
13489
T = eltype(cache.u)
13590

136-
cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
91+
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
13792
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
138-
cache.u = cache.u .- α * cache.du
139-
cache.fu2 = f(cache.u, p)
93+
@bb axpy!(-α, cache.du, cache.u)
14094

141-
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
142-
get_fu(cache), cache.J⁻¹, cache.du, α)
95+
evaluate_f(cache, cache.u, cache.p)
14396

144-
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
145-
cache.stats.nf += 1
97+
update_trace!(cache, α)
98+
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
14699

147100
cache.force_stop && return nothing
148101

149102
# Update the inverse jacobian
150-
cache.dfu = cache.fu2 .- cache.fu
103+
@bb @. cache.dfu = cache.fu - cache.fu_cache
104+
151105
if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
152106
if cache.resets cache.max_resets
153107
cache.retcode = ReturnCode.ConvergenceFailure
154108
cache.force_stop = true
155109
return nothing
156110
end
157-
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
111+
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
158112
cache.resets += 1
159113
else
160-
cache.du = -cache.du
161-
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
162-
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
163-
denom = dot(cache.du, cache.J⁻¹df)
164-
cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
165-
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
114+
@bb cache.du .*= -1
115+
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
116+
@bb cache.u_cache = transpose(cache.J⁻¹) × vec(cache.du)
117+
denom = dot(cache.du, cache.J⁻¹dfu)
118+
@bb @. cache.du = (cache.du - cache.J⁻¹dfu) / ifelse(iszero(denom), T(1e-5), denom)
119+
@bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache))
166120
end
167-
cache.fu = cache.fu2
168-
cache.u_prev = @. cache.u
121+
122+
@bb copyto!(cache.fu_cache, cache.fu)
123+
@bb copyto!(cache.u_cache, cache.u)
169124

170125
return nothing
171126
end
172127

173-
function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
174-
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
175-
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
176-
cache.p = p
177-
if iip
178-
recursivecopy!(cache.u, u0)
179-
cache.f(cache.fu, cache.u, p)
180-
else
181-
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
182-
cache.u = u0
183-
cache.fu = cache.f(cache.u, p)
184-
end
185-
186-
reset!(cache.trace)
187-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
188-
termination_condition)
189-
190-
cache.abstol = abstol
191-
cache.reltol = reltol
192-
cache.tc_cache = tc_cache
193-
cache.maxiters = maxiters
194-
cache.stats.nf = 1
195-
cache.stats.nsteps = 1
128+
function __reinit_internal!(cache::GeneralBroydenCache; kwargs...)
129+
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
196130
cache.resets = 0
197-
cache.force_stop = false
198-
cache.retcode = ReturnCode.Default
199-
return cache
131+
return nothing
200132
end

0 commit comments

Comments
 (0)