|
| 1 | +@enumx NonlinearSafeTerminationReturnCode begin |
| 2 | + Success |
| 3 | + Default |
| 4 | + PatienceTermination |
| 5 | + ProtectiveTermination |
| 6 | + Failure |
| 7 | +end |
| 8 | + |
| 9 | +abstract type AbstractNonlinearTerminationMode end |
| 10 | +abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end |
| 11 | +abstract type AbstractSafeBestNonlinearTerminationMode <: |
| 12 | + AbstractSafeNonlinearTerminationMode end |
| 13 | + |
| 14 | +# TODO: Add a mode where the user can pass in custom termination criteria function |
| 15 | +for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode, |
| 16 | + :NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode, |
| 17 | + :AbsNormTerminationMode) |
| 18 | + @eval begin |
| 19 | + struct $(mode) <: AbstractNonlinearTerminationMode end |
| 20 | + end |
| 21 | +end |
| 22 | + |
| 23 | +for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode) |
| 24 | + @eval begin |
| 25 | + Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode |
| 26 | + protective_threshold::T1 = 1000 |
| 27 | + patience_steps::Int = 30 |
| 28 | + patience_objective_multiplier::T2 = 3 |
| 29 | + min_max_factor::T3 = 1.3 |
| 30 | + end |
| 31 | + end |
| 32 | +end |
| 33 | + |
| 34 | +for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode) |
| 35 | + @eval begin |
| 36 | + Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode |
| 37 | + protective_threshold::T1 = 1000 |
| 38 | + patience_steps::Int = 30 |
| 39 | + patience_objective_multiplier::T2 = 3 |
| 40 | + min_max_factor::T3 = 1.3 |
| 41 | + end |
| 42 | + end |
| 43 | +end |
| 44 | + |
| 45 | +mutable struct NonlinearTerminationModeCache{uType, T, |
| 46 | + M <: AbstractNonlinearTerminationMode, I, OT} |
| 47 | + u::uType |
| 48 | + retcode::NonlinearSafeTerminationReturnCode.T |
| 49 | + abstol::T |
| 50 | + reltol::T |
| 51 | + best_objective_value::T |
| 52 | + mode::M |
| 53 | + initial_objective::I |
| 54 | + objectives_trace::OT |
| 55 | + nsteps::Int |
| 56 | +end |
| 57 | + |
| 58 | +function __update_u!!(cache::NonlinearTerminationModeCache, u) |
| 59 | + cache.u === nothing && return |
| 60 | + if ArrayInterface.can_setindex(cache.u) |
| 61 | + copyto!(cache.u, u) |
| 62 | + else |
| 63 | + cache.u = u |
| 64 | + end |
| 65 | +end |
| 66 | + |
| 67 | +__cvt_real(::Type{T}, ::Nothing) where {T} = nothing |
| 68 | +__cvt_real(::Type{T}, x) where {T} = real(T(x)) |
| 69 | + |
| 70 | +_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η) |
| 71 | +function _get_tolerance(::Nothing, ::Type{T}) where {T} |
| 72 | + η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) |
| 73 | + return _get_tolerance(η, T) |
| 74 | +end |
| 75 | + |
| 76 | +function SciMLBase.init(u::AbstractArray{T}, mode::AbstractNonlinearTerminationMode; |
| 77 | + abstol = nothing, reltol = nothing, kwargs...) where {T} |
| 78 | + abstol = _get_tolerance(abstol, T) |
| 79 | + reltol = _get_tolerance(reltol, T) |
| 80 | + best_value = __cvt_real(T, Inf) |
| 81 | + TT = typeof(abstol) |
| 82 | + u_ = mode isa AbstractSafeBestNonlinearTerminationMode ? |
| 83 | + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing |
| 84 | + if mode isa AbstractSafeNonlinearTerminationMode |
| 85 | + initial_objective = TT(0) |
| 86 | + objectives_trace = Vector{TT}(undef, mode.patience_steps) |
| 87 | + else |
| 88 | + initial_objective = nothing |
| 89 | + objectives_trace = nothing |
| 90 | + end |
| 91 | + return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode), |
| 92 | + typeof(initial_objective), typeof(objectives_trace)}(u_, |
| 93 | + NonlinearSafeTerminationReturnCode.Default, abstol, reltol, best_value, mode, |
| 94 | + initial_objective, objectives_trace, 0) |
| 95 | +end |
| 96 | + |
| 97 | +# This dispatch is needed based on how Terminating Callback works! |
| 98 | +# This intentially drops the `abstol` and `reltol` arguments |
| 99 | +function (cache::NonlinearTerminationModeCache)(integrator, _, _, min_t) |
| 100 | + return cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev) |
| 101 | +end |
| 102 | +(cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev) |
| 103 | + |
| 104 | +function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du, |
| 105 | + u, uprev) |
| 106 | + return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) |
| 107 | +end |
| 108 | + |
| 109 | +function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, |
| 110 | + du, u, uprev) |
| 111 | + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode |
| 112 | + objective = NLSOLVE_DEFAULT_NORM(du) |
| 113 | + criteria = cache.abstol |
| 114 | + else |
| 115 | + objective = NLSOLVE_DEFAULT_NORM(du) / |
| 116 | + (NLSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol)) |
| 117 | + criteria = cache.reltol |
| 118 | + end |
| 119 | + |
| 120 | + # Check if best solution |
| 121 | + if mode isa AbstractSafeBestNonlinearTerminationMode && |
| 122 | + objective < cache.best_objective_value |
| 123 | + cache.best_objective_value = objective |
| 124 | + __update_u!!(cache, u) |
| 125 | + end |
| 126 | + |
| 127 | + # Main Termination Condition |
| 128 | + if objective ≤ criteria |
| 129 | + cache.retcode = NonlinearSafeTerminationReturnCode.Success |
| 130 | + return true |
| 131 | + end |
| 132 | + |
| 133 | + # Terminate if there has been no improvement for the last `patience_steps` |
| 134 | + cache.nsteps += 1 |
| 135 | + cache.nsteps == 1 && (cache.initial_objective = objective) |
| 136 | + cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective |
| 137 | + |
| 138 | + if objective ≤ cache.mode.patience_objective_multiplier * criteria |
| 139 | + if cache.nsteps ≥ cache.mode.patience_steps |
| 140 | + if cache.nsteps < length(cache.objectives_trace) |
| 141 | + min_obj, max_obj = extrema(@view(cache.objectives_trace[1:cache.nsteps])) |
| 142 | + else |
| 143 | + min_obj, max_obj = extrema(cache.objectives_trace) |
| 144 | + end |
| 145 | + if min_obj < cache.mode.min_max_factor * max_obj |
| 146 | + cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination |
| 147 | + return true |
| 148 | + end |
| 149 | + end |
| 150 | + end |
| 151 | + |
| 152 | + # Protective Break |
| 153 | + if objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du) |
| 154 | + cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination |
| 155 | + return true |
| 156 | + end |
| 157 | + |
| 158 | + cache.retcode = NonlinearSafeTerminationReturnCode.Failure |
| 159 | + return false |
| 160 | +end |
| 161 | + |
| 162 | +function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, |
| 163 | + reltol) |
| 164 | + return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) |
| 165 | +end |
| 166 | +function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, |
| 167 | + reltol) |
| 168 | + return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) || |
| 169 | + isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) |
| 170 | +end |
| 171 | +function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) |
| 172 | + du_norm = NLSOLVE_DEFAULT_NORM(duₙ) |
| 173 | + return du_norm ≤ abstol || du_norm ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) |
| 174 | +end |
| 175 | +function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) |
| 176 | + return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) |
| 177 | +end |
| 178 | +function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, |
| 179 | + RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) |
| 180 | + return NLSOLVE_DEFAULT_NORM(duₙ) ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) |
| 181 | +end |
| 182 | +function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) |
| 183 | + return all(abs.(duₙ) .≤ abstol) |
| 184 | +end |
| 185 | +function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, |
| 186 | + AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) |
| 187 | + return NLSOLVE_DEFAULT_NORM(duₙ) ≤ abstol |
| 188 | +end |
| 189 | + |
| 190 | +# NOTE: Deprecate the following API eventually. This API leads to quite a bit of type |
| 191 | +# instability |
1 | 192 | @enumx NLSolveSafeTerminationReturnCode begin
|
2 | 193 | Success
|
3 | 194 | PatienceTermination
|
|
0 commit comments