|
31 | 31 | f |
32 | 32 | alg |
33 | 33 | u |
34 | | - u_prev |
| 34 | + u_cache |
35 | 35 | du |
36 | 36 | fu |
37 | | - fu2 |
| 37 | + fu_cache |
38 | 38 | dfu |
39 | 39 | p |
40 | 40 | J⁻¹ |
41 | | - J⁻¹₂ |
42 | | - J⁻¹df |
| 41 | + J⁻¹dfu |
43 | 42 | force_stop::Bool |
44 | 43 | resets::Int |
45 | 44 | max_resets::Int |
|
57 | 56 | trace |
58 | 57 | end |
59 | 58 |
|
60 | | -get_fu(cache::GeneralBroydenCache) = cache.fu |
61 | | -set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu) |
62 | | - |
63 | 59 | function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...; |
64 | 60 | alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, |
65 | 61 | termination_condition = nothing, internalnorm::F = DEFAULT_NORM, |
66 | 62 | kwargs...) where {uType, iip, F} |
67 | 63 | @unpack f, u0, p = prob |
68 | | - u = alias_u0 ? u0 : deepcopy(u0) |
| 64 | + u = __maybe_unaliased(u0, alias_u0) |
69 | 65 | fu = evaluate_f(prob, u) |
70 | | - du = _mutable_zero(u) |
| 66 | + @bb du = copy(u) |
71 | 67 | J⁻¹ = __init_identity_jacobian(u, fu) |
72 | 68 | reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) : |
73 | 69 | alg.reset_tolerance |
74 | 70 | reset_check = x -> abs(x) ≤ reset_tolerance |
75 | 71 |
|
| 72 | + @bb u_cache = copy(u) |
| 73 | + @bb fu_cache = copy(fu) |
| 74 | + @bb dfu = similar(fu) |
| 75 | + @bb J⁻¹dfu = similar(u) |
| 76 | + |
76 | 77 | abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u, |
77 | 78 | termination_condition) |
78 | 79 | trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true), |
79 | 80 | kwargs...) |
80 | 81 |
|
81 | | - return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu), |
82 | | - zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0, |
83 | | - alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, |
84 | | - reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), |
| 82 | + return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p, |
| 83 | + J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default, |
| 84 | + abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), |
85 | 85 | init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace) |
86 | 86 | end |
87 | 87 |
|
88 | | -function perform_step!(cache::GeneralBroydenCache{true}) |
89 | | - @unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache |
90 | | - T = eltype(u) |
91 | | - |
92 | | - mul!(_vec(du), J⁻¹, _vec(fu)) |
93 | | - α = perform_linesearch!(cache.ls_cache, u, du) |
94 | | - _axpy!(-α, du, u) |
95 | | - f(fu2, u, p) |
96 | | - |
97 | | - update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache), |
98 | | - get_fu(cache), J⁻¹, du, α) |
99 | | - |
100 | | - check_and_update!(cache, fu2, u, u_prev) |
101 | | - cache.stats.nf += 1 |
102 | | - |
103 | | - cache.force_stop && return nothing |
104 | | - |
105 | | - # Update the inverse jacobian |
106 | | - dfu .= fu2 .- fu |
107 | | - |
108 | | - if all(cache.reset_check, du) || all(cache.reset_check, dfu) |
109 | | - if cache.resets ≥ cache.max_resets |
110 | | - cache.retcode = ReturnCode.ConvergenceFailure |
111 | | - cache.force_stop = true |
112 | | - return nothing |
113 | | - end |
114 | | - fill!(J⁻¹, 0) |
115 | | - J⁻¹[diagind(J⁻¹)] .= T(1) |
116 | | - cache.resets += 1 |
117 | | - else |
118 | | - du .*= -1 |
119 | | - mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu)) |
120 | | - mul!(J⁻¹₂, _vec(du)', J⁻¹) |
121 | | - denom = dot(du, J⁻¹df) |
122 | | - du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom) |
123 | | - mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1) |
124 | | - end |
125 | | - fu .= fu2 |
126 | | - @. u_prev = u |
127 | | - |
128 | | - return nothing |
129 | | -end |
130 | | - |
131 | | -function perform_step!(cache::GeneralBroydenCache{false}) |
132 | | - @unpack f, p = cache |
133 | | - |
| 88 | +function perform_step!(cache::GeneralBroydenCache{iip}) where {iip} |
134 | 89 | T = eltype(cache.u) |
135 | 90 |
|
136 | | - cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu)) |
| 91 | + @bb cache.du = cache.J⁻¹ × vec(cache.fu) |
137 | 92 | α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) |
138 | | - cache.u = cache.u .- α * cache.du |
139 | | - cache.fu2 = f(cache.u, p) |
| 93 | + @bb axpy!(-α, cache.du, cache.u) |
140 | 94 |
|
141 | | - update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache), |
142 | | - get_fu(cache), cache.J⁻¹, cache.du, α) |
| 95 | + evaluate_f(cache, cache.u, cache.p) |
143 | 96 |
|
144 | | - check_and_update!(cache, cache.fu2, cache.u, cache.u_prev) |
145 | | - cache.stats.nf += 1 |
| 97 | + update_trace!(cache, α) |
| 98 | + check_and_update!(cache, cache.fu, cache.u, cache.u_cache) |
146 | 99 |
|
147 | 100 | cache.force_stop && return nothing |
148 | 101 |
|
149 | 102 | # Update the inverse jacobian |
150 | | - cache.dfu = cache.fu2 .- cache.fu |
| 103 | + @bb @. cache.dfu = cache.fu - cache.fu_cache |
| 104 | + |
151 | 105 | if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu) |
152 | 106 | if cache.resets ≥ cache.max_resets |
153 | 107 | cache.retcode = ReturnCode.ConvergenceFailure |
154 | 108 | cache.force_stop = true |
155 | 109 | return nothing |
156 | 110 | end |
157 | | - cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu) |
| 111 | + cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹) |
158 | 112 | cache.resets += 1 |
159 | 113 | else |
160 | | - cache.du = -cache.du |
161 | | - cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu)) |
162 | | - cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹ |
163 | | - denom = dot(cache.du, cache.J⁻¹df) |
164 | | - cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom) |
165 | | - cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂ |
| 114 | + @bb cache.du .*= -1 |
| 115 | + @bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu) |
| 116 | + @bb cache.u_cache = transpose(cache.J⁻¹) × vec(cache.du) |
| 117 | + denom = dot(cache.du, cache.J⁻¹dfu) |
| 118 | + @bb @. cache.du = (cache.du - cache.J⁻¹dfu) / ifelse(iszero(denom), T(1e-5), denom) |
| 119 | + @bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache)) |
166 | 120 | end |
167 | | - cache.fu = cache.fu2 |
168 | | - cache.u_prev = @. cache.u |
| 121 | + |
| 122 | + @bb copyto!(cache.fu_cache, cache.fu) |
| 123 | + @bb copyto!(cache.u_cache, cache.u) |
169 | 124 |
|
170 | 125 | return nothing |
171 | 126 | end |
172 | 127 |
|
173 | | -function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p, |
174 | | - abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters, |
175 | | - termination_condition = get_termination_mode(cache.tc_cache)) where {iip} |
176 | | - cache.p = p |
177 | | - if iip |
178 | | - recursivecopy!(cache.u, u0) |
179 | | - cache.f(cache.fu, cache.u, p) |
180 | | - else |
181 | | - # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter |
182 | | - cache.u = u0 |
183 | | - cache.fu = cache.f(cache.u, p) |
184 | | - end |
185 | | - |
186 | | - reset!(cache.trace) |
187 | | - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u, |
188 | | - termination_condition) |
189 | | - |
190 | | - cache.abstol = abstol |
191 | | - cache.reltol = reltol |
192 | | - cache.tc_cache = tc_cache |
193 | | - cache.maxiters = maxiters |
194 | | - cache.stats.nf = 1 |
195 | | - cache.stats.nsteps = 1 |
| 128 | +function __reinit_internal!(cache::GeneralBroydenCache; kwargs...) |
| 129 | + cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹) |
196 | 130 | cache.resets = 0 |
197 | | - cache.force_stop = false |
198 | | - cache.retcode = ReturnCode.Default |
199 | | - return cache |
| 131 | + return nothing |
200 | 132 | end |
0 commit comments