-
Notifications
You must be signed in to change notification settings - Fork 233
Add iterator interface #745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
b10f27e
cc594a8
d6a6d75
2117364
b7b63e8
8a64efb
0826be0
b0e5c30
a22346a
2310160
15c8b62
75ece54
54a8542
bc42ad8
afaa7e9
cc5ebfc
56e0795
dc03af8
3daa277
35ffc80
e027fe9
ea17c1c
9faa883
7620e4f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,36 +27,76 @@ function initial_convergence(d, state, method::AbstractOptimizer, initial_x, opt | |
| end | ||
| initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options) = false | ||
|
|
||
| function optimize(d::D, initial_x::Tx, method::M, | ||
| options::Options{T, TCallback} = Options(;default_options(method)...), | ||
| state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray, T, TCallback} | ||
| if length(initial_x) == 1 && typeof(method) <: NelderMead | ||
| error("You cannot use NelderMead for univariate problems. Alternatively, use either interval bound univariate optimization, or another method such as BFGS or Newton.") | ||
| end | ||
| struct OptimIterator{D <: AbstractObjective, M <: AbstractOptimizer, Tx <: AbstractArray, O <: Options, S} | ||
| d::D | ||
| initial_x::Tx | ||
| method::M | ||
| options::O | ||
| state::S | ||
| end | ||
|
|
||
| Base.IteratorSize(::Type{<:OptimIterator}) = Base.SizeUnknown() | ||
| Base.IteratorEltype(::Type{<:OptimIterator}) = Base.HasEltype() | ||
| Base.eltype(::Type{<:OptimIterator}) = IteratorState | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it a problem that the element type is non-concrete here? Could it be defined in a way that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems below |
||
|
|
||
| @with_kw struct IteratorState{IT <: OptimIterator, TR <: OptimizationTrace} | ||
pkofod marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Put `OptimIterator` in iterator state so that `OptimizationResults` can | ||
| # be constructed from `IteratorState`. | ||
| iter::IT | ||
|
||
|
|
||
| t0::Float64 | ||
| tr::TR | ||
| tracing::Bool | ||
| stopped::Bool | ||
| stopped_by_callback::Bool | ||
| stopped_by_time_limit::Bool | ||
| f_limit_reached::Bool | ||
| g_limit_reached::Bool | ||
| h_limit_reached::Bool | ||
| x_converged::Bool | ||
| f_converged::Bool | ||
| f_increased::Bool | ||
| counter_f_tol::Int | ||
| g_converged::Bool | ||
| converged::Bool | ||
| iteration::Int | ||
| ls_success::Bool | ||
| end | ||
|
|
||
| t0 = time() # Initial time stamp used to control early stopping by options.time_limit | ||
| function Base.iterate(iter::OptimIterator, istate = nothing) | ||
| @unpack d, initial_x, method, options, state = iter | ||
pkofod marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if istate === nothing | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO it would be cleaner to move this code to a separate |
||
| t0 = time() # Initial time stamp used to control early stopping by options.time_limit | ||
|
|
||
| tr = OptimizationTrace{typeof(value(d)), typeof(method)}() | ||
| tracing = options.store_trace || options.show_trace || options.extended_trace || options.callback != nothing | ||
| stopped, stopped_by_callback, stopped_by_time_limit = false, false, false | ||
| f_limit_reached, g_limit_reached, h_limit_reached = false, false, false | ||
| x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0 | ||
| tr = OptimizationTrace{typeof(value(d)), typeof(method)}() | ||
| tracing = options.store_trace || options.show_trace || options.extended_trace || options.callback != nothing | ||
| stopped, stopped_by_callback, stopped_by_time_limit = false, false, false | ||
| f_limit_reached, g_limit_reached, h_limit_reached = false, false, false | ||
| x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0 | ||
|
|
||
| g_converged = initial_convergence(d, state, method, initial_x, options) | ||
| converged = g_converged | ||
| g_converged = initial_convergence(d, state, method, initial_x, options) | ||
| converged = g_converged | ||
|
|
||
| # prepare iteration counter (used to make "initial state" trace entry) | ||
| iteration = 0 | ||
| # prepare iteration counter (used to make "initial state" trace entry) | ||
| iteration = 0 | ||
|
|
||
| options.show_trace && print_header(method) | ||
| trace!(tr, d, state, iteration, method, options, time()-t0) | ||
| ls_success::Bool = true | ||
|
|
||
| # Note: `optimize` depends on that first iteration always yields something | ||
| # (i.e., `iterate` does _not_ return a `nothing` when `istate === nothing`). | ||
| else | ||
| @unpack_IteratorState istate | ||
pkofod marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| !converged && !stopped && iteration < options.iterations || return nothing | ||
|
|
||
| options.show_trace && print_header(method) | ||
| trace!(tr, d, state, iteration, method, options, time()-t0) | ||
| ls_success::Bool = true | ||
| while !converged && !stopped && iteration < options.iterations | ||
| iteration += 1 | ||
|
|
||
| ls_failed = update_state!(d, state, method) | ||
| if !ls_success | ||
| break # it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors) | ||
| # it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors) | ||
| return nothing | ||
| end | ||
| update_g!(d, state, method) # TODO: Should this be `update_fg!`? | ||
|
|
||
|
|
@@ -85,7 +125,35 @@ function optimize(d::D, initial_x::Tx, method::M, | |
| stopped_by_time_limit || f_limit_reached || g_limit_reached || h_limit_reached | ||
| stopped = true | ||
| end | ||
| end # while | ||
| end | ||
|
|
||
| new_istate = IteratorState( | ||
| iter, | ||
| t0, | ||
| tr, | ||
| tracing, | ||
| stopped, | ||
| stopped_by_callback, | ||
| stopped_by_time_limit, | ||
| f_limit_reached, | ||
| g_limit_reached, | ||
| h_limit_reached, | ||
| x_converged, | ||
| f_converged, | ||
| f_increased, | ||
| counter_f_tol, | ||
| g_converged, | ||
| converged, | ||
| iteration, | ||
| ls_success, | ||
| ) | ||
|
|
||
| return new_istate, new_istate | ||
| end | ||
|
|
||
| function OptimizationResults(istate::IteratorState) | ||
| @unpack_IteratorState istate | ||
pkofod marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| @unpack d, initial_x, method, options, state = iter | ||
pkofod marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| after_while!(d, state, method, options) | ||
|
||
|
|
||
|
|
@@ -94,6 +162,9 @@ function optimize(d::D, initial_x::Tx, method::M, | |
| Tf = typeof(value(d)) | ||
| f_incr_pick = f_increased && !options.allow_f_increases | ||
|
|
||
| T = typeof(options.x_abstol) | ||
| Tx = typeof(initial_x) | ||
|
|
||
| return MultivariateOptimizationResults{typeof(method),T,Tx,typeof(x_abschange(state)),Tf,typeof(tr), Bool}(method, | ||
| initial_x, | ||
| pick_best_x(f_incr_pick, state), | ||
|
|
@@ -120,3 +191,34 @@ function optimize(d::D, initial_x::Tx, method::M, | |
| h_calls(d), | ||
| !ls_success) | ||
| end | ||
|
|
||
| function optimizing(d::D, initial_x::Tx, method::M, | ||
| options::Options = Options(;default_options(method)...), | ||
| state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray} | ||
| if length(initial_x) == 1 && typeof(method) <: NelderMead | ||
| error("You cannot use NelderMead for univariate problems. Alternatively, use either interval bound univariate optimization, or another method such as BFGS or Newton.") | ||
| end | ||
| return OptimIterator(d, initial_x, method, options, state) | ||
| end | ||
|
|
||
| # Derive `IteratorState` accessors from `MultivariateOptimizationResults` accessors. | ||
| for f in [ | ||
| :(Base.summary) | ||
| :minimizer | ||
| :minimum | ||
| :iterations | ||
| :iteration_limit_reached | ||
| :trace | ||
| :x_trace | ||
| :f_trace | ||
| :f_calls | ||
| :converged | ||
| :g_norm_trace | ||
| :g_calls | ||
| :x_converged | ||
| :f_converged | ||
| :g_converged | ||
| :initial_state | ||
| ] | ||
| @eval $f(istate::IteratorState) = $f(OptimizationResults(istate)) | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think JET won't agree with this comment 😄
Generally, the code above seems a bit unfortunate... Maybe
optimizingshould return the iterator AND the initial state?I also wonder, is there no utility in Julia for directly obtaining the last state of an iterator?