|
31 | 31 | f
|
32 | 32 | alg
|
33 | 33 | u
|
34 |
| - u_prev |
| 34 | + u_cache |
35 | 35 | du
|
36 | 36 | fu
|
37 |
| - fu2 |
| 37 | + fu_cache |
38 | 38 | dfu
|
39 | 39 | p
|
40 | 40 | J⁻¹
|
41 |
| - J⁻¹₂ |
42 |
| - J⁻¹df |
| 41 | + J⁻¹dfu |
43 | 42 | force_stop::Bool
|
44 | 43 | resets::Int
|
45 | 44 | max_resets::Int
|
|
57 | 56 | trace
|
58 | 57 | end
|
59 | 58 |
|
60 |
| -get_fu(cache::GeneralBroydenCache) = cache.fu |
61 |
| -set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu) |
62 |
| - |
63 | 59 | function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
|
64 | 60 | alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
|
65 | 61 | termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
|
66 | 62 | kwargs...) where {uType, iip, F}
|
67 | 63 | @unpack f, u0, p = prob
|
68 |
| - u = alias_u0 ? u0 : deepcopy(u0) |
| 64 | + u = __maybe_unaliased(u0, alias_u0) |
69 | 65 | fu = evaluate_f(prob, u)
|
70 |
| - du = _mutable_zero(u) |
| 66 | + @bb du = copy(u) |
71 | 67 | J⁻¹ = __init_identity_jacobian(u, fu)
|
72 | 68 | reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
|
73 | 69 | alg.reset_tolerance
|
74 | 70 | reset_check = x -> abs(x) ≤ reset_tolerance
|
75 | 71 |
|
| 72 | + @bb u_cache = copy(u) |
| 73 | + @bb fu_cache = copy(fu) |
| 74 | + @bb dfu = similar(fu) |
| 75 | + @bb J⁻¹dfu = similar(u) |
| 76 | + |
76 | 77 | abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
|
77 | 78 | termination_condition)
|
78 | 79 | trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
|
79 | 80 | kwargs...)
|
80 | 81 |
|
81 |
| - return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu), |
82 |
| - zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0, |
83 |
| - alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, |
84 |
| - reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), |
| 82 | + return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p, |
| 83 | + J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default, |
| 84 | + abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), |
85 | 85 | init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
|
86 | 86 | end
|
87 | 87 |
|
88 |
| -function perform_step!(cache::GeneralBroydenCache{true}) |
89 |
| - @unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache |
90 |
| - T = eltype(u) |
91 |
| - |
92 |
| - mul!(_vec(du), J⁻¹, _vec(fu)) |
93 |
| - α = perform_linesearch!(cache.ls_cache, u, du) |
94 |
| - _axpy!(-α, du, u) |
95 |
| - f(fu2, u, p) |
96 |
| - |
97 |
| - update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache), |
98 |
| - get_fu(cache), J⁻¹, du, α) |
99 |
| - |
100 |
| - check_and_update!(cache, fu2, u, u_prev) |
101 |
| - cache.stats.nf += 1 |
102 |
| - |
103 |
| - cache.force_stop && return nothing |
104 |
| - |
105 |
| - # Update the inverse jacobian |
106 |
| - dfu .= fu2 .- fu |
107 |
| - |
108 |
| - if all(cache.reset_check, du) || all(cache.reset_check, dfu) |
109 |
| - if cache.resets ≥ cache.max_resets |
110 |
| - cache.retcode = ReturnCode.ConvergenceFailure |
111 |
| - cache.force_stop = true |
112 |
| - return nothing |
113 |
| - end |
114 |
| - fill!(J⁻¹, 0) |
115 |
| - J⁻¹[diagind(J⁻¹)] .= T(1) |
116 |
| - cache.resets += 1 |
117 |
| - else |
118 |
| - du .*= -1 |
119 |
| - mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu)) |
120 |
| - mul!(J⁻¹₂, _vec(du)', J⁻¹) |
121 |
| - denom = dot(du, J⁻¹df) |
122 |
| - du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom) |
123 |
| - mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1) |
124 |
| - end |
125 |
| - fu .= fu2 |
126 |
| - @. u_prev = u |
127 |
| - |
128 |
| - return nothing |
129 |
| -end |
130 |
| - |
131 |
| -function perform_step!(cache::GeneralBroydenCache{false}) |
132 |
| - @unpack f, p = cache |
133 |
| - |
| 88 | +function perform_step!(cache::GeneralBroydenCache{iip}) where {iip} |
134 | 89 | T = eltype(cache.u)
|
135 | 90 |
|
136 |
| - cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu)) |
| 91 | + @bb cache.du = cache.J⁻¹ × vec(cache.fu) |
137 | 92 | α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
|
138 |
| - cache.u = cache.u .- α * cache.du |
139 |
| - cache.fu2 = f(cache.u, p) |
| 93 | + @bb axpy!(-α, cache.du, cache.u) |
140 | 94 |
|
141 |
| - update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache), |
142 |
| - get_fu(cache), cache.J⁻¹, cache.du, α) |
| 95 | + evaluate_f(cache, cache.u, cache.p) |
143 | 96 |
|
144 |
| - check_and_update!(cache, cache.fu2, cache.u, cache.u_prev) |
145 |
| - cache.stats.nf += 1 |
| 97 | + update_trace!(cache, α) |
| 98 | + check_and_update!(cache, cache.fu, cache.u, cache.u_cache) |
146 | 99 |
|
147 | 100 | cache.force_stop && return nothing
|
148 | 101 |
|
149 | 102 | # Update the inverse jacobian
|
150 |
| - cache.dfu = cache.fu2 .- cache.fu |
| 103 | + @bb @. cache.dfu = cache.fu - cache.fu_cache |
| 104 | + |
151 | 105 | if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
|
152 | 106 | if cache.resets ≥ cache.max_resets
|
153 | 107 | cache.retcode = ReturnCode.ConvergenceFailure
|
154 | 108 | cache.force_stop = true
|
155 | 109 | return nothing
|
156 | 110 | end
|
157 |
| - cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu) |
| 111 | + cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹) |
158 | 112 | cache.resets += 1
|
159 | 113 | else
|
160 |
| - cache.du = -cache.du |
161 |
| - cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu)) |
162 |
| - cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹ |
163 |
| - denom = dot(cache.du, cache.J⁻¹df) |
164 |
| - cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom) |
165 |
| - cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂ |
| 114 | + @bb cache.du .*= -1 |
| 115 | + @bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu) |
| 116 | + @bb cache.u_cache = transpose(cache.J⁻¹) × vec(cache.du) |
| 117 | + denom = dot(cache.du, cache.J⁻¹dfu) |
| 118 | + @bb @. cache.du = (cache.du - cache.J⁻¹dfu) / ifelse(iszero(denom), T(1e-5), denom) |
| 119 | + @bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache)) |
166 | 120 | end
|
167 |
| - cache.fu = cache.fu2 |
168 |
| - cache.u_prev = @. cache.u |
| 121 | + |
| 122 | + @bb copyto!(cache.fu_cache, cache.fu) |
| 123 | + @bb copyto!(cache.u_cache, cache.u) |
169 | 124 |
|
170 | 125 | return nothing
|
171 | 126 | end
|
172 | 127 |
|
173 |
| -function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p, |
174 |
| - abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters, |
175 |
| - termination_condition = get_termination_mode(cache.tc_cache)) where {iip} |
176 |
| - cache.p = p |
177 |
| - if iip |
178 |
| - recursivecopy!(cache.u, u0) |
179 |
| - cache.f(cache.fu, cache.u, p) |
180 |
| - else |
181 |
| - # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter |
182 |
| - cache.u = u0 |
183 |
| - cache.fu = cache.f(cache.u, p) |
184 |
| - end |
185 |
| - |
186 |
| - reset!(cache.trace) |
187 |
| - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u, |
188 |
| - termination_condition) |
189 |
| - |
190 |
| - cache.abstol = abstol |
191 |
| - cache.reltol = reltol |
192 |
| - cache.tc_cache = tc_cache |
193 |
| - cache.maxiters = maxiters |
194 |
| - cache.stats.nf = 1 |
195 |
| - cache.stats.nsteps = 1 |
| 128 | +function __reinit_internal!(cache::GeneralBroydenCache; kwargs...) |
| 129 | + cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹) |
196 | 130 | cache.resets = 0
|
197 |
| - cache.force_stop = false |
198 |
| - cache.retcode = ReturnCode.Default |
199 |
| - return cache |
| 131 | + return nothing |
200 | 132 | end
|
0 commit comments