docs/src/user/minimization.md (19 additions, 0 deletions)

@@ -235,3 +235,22 @@
line search errors if `initial_x` is a stationary point. Notice that this is only
a first-order check. If `initial_x` is any type of stationary point, `g_converged`
will be true. This includes local minima, saddle points, and local maxima. If `iterations` is `0`
and `g_converged` is `true`, the user should keep this point in mind.

## Iterator interface
For multivariable optimization, an iterator interface is provided through the
`Optim.optimizing` function. Using this interface, `optimize(args...; kwargs...)` is
equivalent to

```jl
let istate
    for istate′ in Optim.optimizing(args...; kwargs...)
        istate = istate′
    end
    Optim.OptimizationResults(istate)
end
```

The iterator returned by `Optim.optimizing` yields an iterator state for each iteration
step.

Functions that can be called on the result object (e.g. `minimizer` and `iterations`; see
[Complete list of functions](@ref)) can also be called on the iteration state `istate`.
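
For example, the iterator makes it possible to monitor progress and stop early on a
custom condition. A minimal sketch (the objective, method, and tolerance here are
illustrative, and it assumes the accessor behavior on `istate` described above):

```jl
using Optim

rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

result = let istate
    for istate′ in Optim.optimizing(rosenbrock, zeros(2), BFGS())
        istate = istate′
        # Custom stopping rule: bail out once the objective value is tiny.
        Optim.minimum(istate) < 1e-8 && break
    end
    Optim.OptimizationResults(istate)
end
```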
src/api.jl (10 additions, 7 deletions)

@@ -1,4 +1,5 @@
Base.summary(r::OptimizationResults) = summary(r.method) # might want to do more here than just return summary of the method used
Base.summary(r::Union{OptimizationResults, IteratorState}) =
summary(AbstractOptimizer(r)) # might want to do more here than just return summary of the method used
minimizer(r::OptimizationResults) = r.minimizer
minimum(r::OptimizationResults) = r.minimum
iterations(r::OptimizationResults) = r.iterations
@@ -9,6 +9,8 @@ trace(r::OptimizationResults) =
"No trace in optimization results. To get a trace, run optimize() with store_trace = true.",
)

AbstractOptimizer(r::OptimizationResults) = r.method

function x_trace(r::UnivariateOptimizationResults)
tr = trace(r)
!haskey(tr[1].metadata, "minimizer") && error(
@@ -35,7 +38,7 @@ end
x_upper_trace(r::MultivariateOptimizationResults) =
error("x_upper_trace is not implemented for $(summary(r)).")

function x_trace(r::MultivariateOptimizationResults)
function x_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if isa(r.method, NelderMead)
throw(
@@ -50,7 +53,7 @@ function x_trace(r::MultivariateOptimizationResults)
[state.metadata["x"] for state in tr]
end

function centroid_trace(r::MultivariateOptimizationResults)
function centroid_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
@tkf (Contributor Author) commented on Sep 10, 2019:

I suppose `tr = trace(r)` should be added here? I think it'll throw `UndefVarError` otherwise. There are two more places I did this change.

I'm including these changes (adding `tr = trace(r)`) although they are not directly related to this PR.

if !isa(r.method, NelderMead)
throw(
@@ -64,7 +67,7 @@ function centroid_trace(r::MultivariateOptimizationResults)
)
[state.metadata["centroid"] for state in tr]
end
function simplex_trace(r::MultivariateOptimizationResults)
function simplex_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if !isa(r.method, NelderMead)
throw(
@@ -78,7 +81,7 @@ function simplex_trace(r::MultivariateOptimizationResults)
)
[state.metadata["simplex"] for state in tr]
end
function simplex_value_trace(r::MultivariateOptimizationResults)
function simplex_value_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if !isa(r.method, NelderMead)
throw(
@@ -94,10 +97,10 @@ function simplex_value_trace(r::MultivariateOptimizationResults)
end


f_trace(r::OptimizationResults) = [state.value for state in trace(r)]
f_trace(r::Union{OptimizationResults, IteratorState}) = [state.value for state in trace(r)]
g_norm_trace(r::OptimizationResults) =
error("g_norm_trace is not implemented for $(summary(r)).")
g_norm_trace(r::MultivariateOptimizationResults) = [state.g_norm for state in trace(r)]
g_norm_trace(r::Union{OptimizationResults, IteratorState}) = [state.g_norm for state in trace(r)]

f_calls(r::OptimizationResults) = r.f_calls
f_calls(d) = first(d.f_calls)
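
With the accessor signatures widened to `Union{..., IteratorState}`, traces can be
inspected while the optimization is still running. A sketch (it assumes a run with
`store_trace = true` and that `trace` works on the iteration state as the unions above
imply; the objective and method are illustrative):

```jl
using Optim

f(x) = sum(abs2, x)
opts = Optim.Options(store_trace = true)

for istate in Optim.optimizing(f, [1.0, 2.0], GradientDescent(), opts)
    # These accept the iteration state directly thanks to the Union signatures.
    println("objective values so far: ", Optim.f_trace(istate))
    println("gradient norms so far:   ", Optim.g_norm_trace(istate))
end
```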
src/multivariate/optimize/interface.jl (26 additions, 10 deletions)

@@ -173,7 +173,7 @@ function optimize(
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -190,7 +190,7 @@
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -208,10 +208,14 @@
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

# no method supplied with objective
function optimizing(d::T, initial_x::AbstractArray, options::Options) where T<:AbstractObjective
optimizing(d, initial_x, fallback_method(d), options)
end
# no method supplied with inplace and autodiff keywords because objective is not supplied
function optimize(
d::T,
initial_x::AbstractArray,
@@ -229,7 +233,7 @@
)
method = fallback_method(f)
d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -242,7 +246,7 @@

method = fallback_method(f, g)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -257,7 +261,7 @@
method = fallback_method(f, g, h)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)

optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

# potentially everything is supplied (besides caches)
@@ -270,8 +274,9 @@ function optimize(
autodiff = :finite,
)


d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -298,7 +303,7 @@

d = promote_objtype(method, initial_x, autodiff, inplace, f, g)

optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -313,7 +318,7 @@

d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)

optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

function optimize(
@@ -324,6 +329,17 @@ function optimize(
autodiff = :finite,
inplace = true,
) where {D<:Union{NonDifferentiable,OnceDifferentiable}}

d = promote_objtype(method, initial_x, autodiff, inplace, d)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

function optimize(args...; kwargs...)
local istate
for istate′ in optimizing(args...; kwargs...)
istate = istate′
end
# We can safely assume that `istate` is defined at this point. That is to say,
# `OptimIterator` guarantees that `iterate(::OptimIterator) !== nothing`.
A contributor commented on lines +334 to +335:

I think JET won't agree with this comment 😄

Generally, the code above seems a bit unfortunate... Maybe `optimizing` should return the iterator AND the initial state?

I also wonder, is there no utility in Julia for directly obtaining the last state of an iterator?

return OptimizationResults(istate)
end
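
On the reviewer's last question: `Base` has no generic utility for the last element of
an arbitrary iterator, but one can be written with `foldl`. A sketch; `lastitem` is a
hypothetical helper, not part of Base or Optim:

```jl
# Hypothetical helper: fold the iterator, keeping only the most recently
# yielded element. Like `reduce`, it throws on an empty iterator.
lastitem(itr) = foldl((_, x) -> x, itr)
```

And regarding the comment JET would dispute, a more defensive variant of the driver
makes the definedness of `istate` explicit. This is only a sketch under the PR's
`optimizing`/`OptimizationResults` API; `optimize_strict` is a hypothetical name:

```jl
function optimize_strict(args...; kwargs...)
    iter = optimizing(args...; kwargs...)
    y = iterate(iter)
    # Fail loudly instead of risking an UndefVarError on an empty iterator.
    y === nothing && error("`optimizing` yielded no iteration states")
    istate, st = y
    while (y = iterate(iter, st)) !== nothing
        istate, st = y
    end
    return OptimizationResults(istate)
end
```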