Skip to content

Commit 81b506e

Browse files
committed
Add accessor functions
1 parent d6a6d75 commit 81b506e

File tree

4 files changed

+77
-3
lines changed

4 files changed

+77
-3
lines changed

docs/src/user/minimization.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,22 @@ line search errors if `initial_x` is a stationary point. Notice, that this is on
219219
a first order check. If `initial_x` is any type of stationary point, `g_converged`
220220
will be true. This includes local minima, saddle points, and local maxima. If `iterations` is `0`
221221
and `g_converged` is `true`, the user needs to keep this point in mind.
222+
223+
## Iterator interface
224+
For multivariable optimizations, iterator interface is provided through `Optim.optimizing`
225+
function. Using this interface, `optimize(args...; kwargs...)` is equivalent to
226+
227+
```jl
228+
let istate
229+
for istate′ in Optim.optimizing(args...; kwargs...)
230+
istate = istate′
231+
end
232+
Optim.OptimizationResults(istate)
233+
end
234+
```
235+
236+
The iterator returned by `Optim.optimizing` yields an iterator state for each iteration
237+
step.
238+
239+
Functions that can be called on the result object (e.g. `minimizer`, `iterations`; see
240+
[Complete list of functions](@ref)) can be used on the iteration state `istate`.

src/api.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,34 @@ rel_tol(r::OptimizationResults) = error("rel_tol is not implemented for $(summar
118118
rel_tol(r::UnivariateOptimizationResults) = r.rel_tol
119119
abs_tol(r::OptimizationResults) = error("abs_tol is not implemented for $(summary(r)).")
120120
abs_tol(r::UnivariateOptimizationResults) = r.abs_tol
121+
122+
123+
# Derive `IteratorState` accessors from `MultivariateOptimizationResults` accessors.
124+
125+
# Result accessors that does _not_ need to run `after_while!`:
126+
for f in [
127+
:(Base.summary)
128+
:iterations
129+
:iteration_limit_reached
130+
:trace
131+
:x_trace
132+
:f_trace
133+
:f_calls
134+
:converged
135+
:g_norm_trace
136+
:g_calls
137+
:x_converged
138+
:f_converged
139+
:g_converged
140+
:initial_state
141+
]
142+
@eval $f(istate::IteratorState) = $f(_OptimizationResults(istate))
143+
end
144+
145+
# Result accessors that need to run `after_while!`:
146+
for f in [
147+
:minimizer
148+
:minimum
149+
]
150+
@eval $f(istate::IteratorState) = $f(OptimizationResults(istate))
151+
end

src/multivariate/optimize/optimize.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,15 @@ function Base.iterate(iter::OptimIterator, istate = nothing)
152152
end
153153

154154
function OptimizationResults(istate::IteratorState)
155+
@unpack d, method, options, state = istate.iter
156+
after_while!(d, state, method, options)
157+
return _OptimizationResults(istate)
158+
end
159+
160+
function _OptimizationResults(istate::IteratorState)
155161
@unpack_IteratorState istate
156162
@unpack d, initial_x, method, options, state = iter
157163

158-
after_while!(d, state, method, options)
159-
160164
# we can just check minimum, as we've earlier enforced same types/eltypes
161165
# in variables besides the option settings
162166
Tf = typeof(value(d))

test/general/api.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,29 @@
146146
@test haskey(Optim.trace(res_extended_nm)[1].metadata,"step_type")
147147

148148
local istate
149-
for istate′ in Optim.optimizing(f, initial_x, BFGS())
149+
for istate′ in Optim.optimizing(f, initial_x, BFGS(),
150+
Optim.Options(extended_trace = true,
151+
store_trace = true))
150152
istate = istate′
153+
break
151154
end
155+
# (smoke) tests for accessor functions:
156+
@test summary(istate) == "BFGS"
157+
@test Optim.minimizer(istate) isa Vector{Float64}
158+
@test Optim.minimum(istate) isa Float64
159+
@test Optim.iterations(istate) == 0
160+
@test Optim.iteration_limit_reached(istate) == false
161+
@test Optim.trace(istate) isa Vector{<:Optim.OptimizationState}
162+
@test Optim.x_trace(istate) isa Vector{Vector{Float64}}
163+
@test Optim.f_trace(istate) isa Vector{Float64}
164+
@test Optim.f_calls(istate) == 1
165+
@test Optim.converged(istate) == false
166+
@test Optim.g_norm_trace(istate) isa Vector{Float64}
167+
@test Optim.g_calls(istate) == 1
168+
@test Optim.x_converged(istate) == false
169+
@test Optim.f_converged(istate) == false
170+
@test Optim.g_converged(istate) == false
171+
@test Optim.initial_state(istate) == initial_x
152172
@test Optim.OptimizationResults(istate) isa Optim.MultivariateOptimizationResults
153173
end
154174

0 commit comments

Comments
 (0)