Skip to content

Commit 89c784a

Browse files
committed
Better initial objective
1 parent 9b2890f commit 89c784a

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

src/termination_conditions.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol
6161

6262
function __update_u!!(cache::NonlinearTerminationModeCache, u)
6363
cache.u === nothing && return
64-
if ArrayInterface.can_setindex(cache.u)
64+
if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u)
6565
copyto!(cache.u, u)
6666
else
6767
cache.u = u
@@ -77,21 +77,27 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T}
7777
return _get_tolerance(η, T)
7878
end
7979

80-
function SciMLBase.init(u::Union{AbstractArray{T}, T},
80+
function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
8181
mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing,
8282
kwargs...) where {T <: Number}
8383
abstol = _get_tolerance(abstol, T)
8484
reltol = _get_tolerance(reltol, T)
85-
best_value = __cvt_real(T, Inf)
8685
TT = typeof(abstol)
8786
u_ = mode isa AbstractSafeBestNonlinearTerminationMode ?
8887
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
8988
if mode isa AbstractSafeNonlinearTerminationMode
90-
initial_objective = TT(0)
89+
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
90+
initial_objective = NONLINEARSOLVE_DEFAULT_NORM(du)
91+
else
92+
initial_objective = NONLINEARSOLVE_DEFAULT_NORM(du) /
93+
(NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + eps(TT))
94+
end
9195
objectives_trace = Vector{TT}(undef, mode.patience_steps)
96+
best_value = initial_objective
9297
else
9398
initial_objective = nothing
9499
objectives_trace = nothing
100+
best_value = __cvt_real(T, Inf)
95101
end
96102
return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode),
97103
typeof(initial_objective), typeof(objectives_trace)}(u_,
@@ -122,6 +128,13 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
122128
criteria = cache.reltol
123129
end
124130

131+
# Protective Break
132+
if isinf(objective) || isnan(objective) ||
133+
(objective cache.initial_objective * cache.mode.protective_threshold * length(du))
134+
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
135+
return true
136+
end
137+
125138
# Check if best solution
126139
if mode isa AbstractSafeBestNonlinearTerminationMode &&
127140
objective < cache.best_objective_value
@@ -154,12 +167,6 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
154167
end
155168
end
156169

157-
# Protective Break
158-
if objective cache.initial_objective * cache.mode.protective_threshold * length(du)
159-
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
160-
return true
161-
end
162-
163170
cache.retcode = NonlinearSafeTerminationReturnCode.Failure
164171
return false
165172
end
@@ -238,9 +245,10 @@ function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64,
238245
best_objective_value_iteration = 0,
239246
return_code = NLSolveSafeTerminationReturnCode.Failure)
240247
u = u !== nothing ? copy(u) : u
248+
Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
249+
:NLSolveSafeTerminationResult)
241250
return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(u,
242-
best_objective_value,
243-
best_objective_value_iteration, return_code)
251+
best_objective_value, best_objective_value_iteration, return_code)
244252
end
245253

246254
const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
@@ -296,6 +304,8 @@ Define the termination criteria for the NonlinearProblem or SteadyStateProblem.
296304
* `protective_threshold`: If the objective value increased by this factor wrt initial objective terminate immediately.
297305
* `patience_steps`: If objective is within `patience_objective_multiplier` factor of the criteria and no improvement within `min_max_factor` has happened then terminate.
298306
307+
!!! warning
308+
This has been deprecated and will be removed in the next major release. Please use the new dispatch based termination conditions API.
299309
"""
300310
struct NLSolveTerminationCondition{mode, T,
301311
S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
@@ -323,6 +333,8 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
323333
protective_threshold = 1e3, patience_steps::Int = 30,
324334
patience_objective_multiplier = 3,
325335
min_max_factor = 1.3) where {T}
336+
Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
337+
:NLSolveTerminationCondition)
326338
@assert mode instances(NLSolveTerminationMode.T)
327339
options = if mode SAFE_TERMINATION_MODES
328340
NLSolveSafeTerminationOptions(protective_threshold, patience_steps,

0 commit comments

Comments
 (0)