Commit aa13e0c

Optimization improvements (#2221)

Authored by mhauru, with co-authors torfjelde and yebai.

* initial work on interface
* Improving the Optimization.jl interface, work in progress
* More work on Optimization.jl, still in progress
* Add docstrings to Optimisation.jl
* Fix OptimizationOptimJL version constraint
* Clean up optimisation TODO notes
* Relax OptimizationOptimJL version constraints
* Simplify optimization imports
* Remove commented out code
* Small improvements all over in optimisation
* Clean up of Optimisation tests
* Add a test for OptimizationBBO
* Add tests using OptimizationNLopt
* Rename/move the optimisation test files. The files for Optimisation.jl and OptimInterface.jl were in the wrong folders: one in `test/optimisation`, the other in `test/ext`, but the wrong way around.
* Relax compat bounds on OptimizationBBO and OptimizationNLopt
* Split a testset to test/optimisation/OptimisationCore.jl
* Import AbstractADType from ADTypes, not SciMLBase
* Fix Optimization.jl depwarning
* Fix seeds in more tests
* Merge OptimizationCore into Optimization
* In optimisation, rename init_value to initial_params
* Optimisation docstring improvements
* Code style adjustments in optimisation
* Qualify references in optimisation
* Simplify creation of ModeResults
* Qualified references in optimization tests
* Enforce line length in optimization
* Simplify optimisation exports
* Enforce line length in Optim.jl interface
* Refactor away ModeEstimationProblem
* Style and docstring improvements for optimisation
* Add := test to optimisation tests
* Clarify comment
* Simplify generate_initial_params
* Fix doc references
* Rename testsets
* Refactor check_success
* Make initial_params a kwarg
* Remove unnecessary type constraint on kwarg
* Fix broken reference in tests
* Fix bug in generate_initial_params
* Fix qualified references in optimisation tests
* Add hasstats checks to optimisation tests
* Extend OptimizationOptimJL compat to 0.3 (Co-authored-by: Hong Ge <[email protected]>)
* Change some `import`s to `using` (Co-authored-by: Tor Erlend Fjelde <[email protected]>)
* Change `<keyword arguments>` to `kwargs...` in docstrings
* Add a two-argument method to OptimLogDensity as callable

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
1 parent e9fddbe commit aa13e0c
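The headline change is a reworked mode-estimation interface built on Optimization.jl. Below is a minimal sketch of how the new exports are meant to be used, based on the commit message and the export changes in `src/Turing.jl` further down; the exact signatures, and `initial_params` as a keyword ("Make initial_params a kwarg"), should be treated as assumptions rather than documented API:

```julia
using Turing

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    y ~ Normal(m, sqrt(s²))
end

model = gdemo(1.5, 2.0)

# New entry points exported by this commit.
mle = maximum_likelihood(model)
# `initial_params` is named per the commit message; the values for [s², m] are
# hypothetical, chosen only to be in the support of the priors.
map_est = maximum_a_posteriori(model; initial_params=[1.0, 0.0])
```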

File tree: 9 files changed (+1024, -549 lines)


Project.toml

Lines changed: 4 additions & 0 deletions
```diff
@@ -25,6 +25,8 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
+Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -68,6 +70,8 @@ LogDensityProblems = "2"
 LogDensityProblemsAD = "1.7.0"
 MCMCChains = "5, 6"
 NamedArrays = "0.9, 0.10"
+Optimization = "3"
+OptimizationOptimJL = "0.1, 0.2, 0.3"
 OrderedCollections = "1"
 Optim = "1"
 Reexport = "0.2, 1"
```
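`Optimization` and `OptimizationOptimJL` become direct dependencies because the new interface drives mode estimation through Optimization.jl. That library's calling convention is that an objective is a two-argument callable `f(u, p)` (optimisation variables plus fixed parameters), which is presumably why the commit also "adds a two-argument method to OptimLogDensity as callable". A standalone sketch of that convention, using the stock Rosenbrock example rather than anything Turing-specific:

```julia
using Optimization, OptimizationOptimJL

# Optimization.jl objectives take (u, p): u are the variables being optimised,
# p is fixed data. Any two-argument callable fits this interface.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
sol = solve(prob, LBFGS())  # LBFGS is re-exported by OptimizationOptimJL
```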

ext/TuringOptimExt.jl

Lines changed: 67 additions & 126 deletions
````diff
@@ -2,105 +2,22 @@ module TuringOptimExt
 
 if isdefined(Base, :get_extension)
     import Turing
-    import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
+    import Turing:
+        DynamicPPL,
+        NamedArrays,
+        Accessors,
+        Optimisation
     import Optim
 else
     import ..Turing
-    import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
+    import ..Turing:
+        DynamicPPL,
+        NamedArrays,
+        Accessors,
+        Optimisation
     import ..Optim
 end
 
-"""
-    ModeResult{
-        V<:NamedArrays.NamedArray,
-        M<:NamedArrays.NamedArray,
-        O<:Optim.MultivariateOptimizationResults,
-        S<:NamedArrays.NamedArray
-    }
-
-A wrapper struct to store various results from a MAP or MLE estimation.
-"""
-struct ModeResult{
-    V<:NamedArrays.NamedArray,
-    O<:Optim.MultivariateOptimizationResults,
-    M<:Turing.OptimLogDensity
-} <: StatsBase.StatisticalModel
-    "A vector with the resulting point estimates."
-    values::V
-    "The stored Optim.jl results."
-    optim_result::O
-    "The final log likelihood or log joint, depending on whether `MAP` or `MLE` was run."
-    lp::Float64
-    "The evaluation function used to calculate the output."
-    f::M
-end
-
-#############################
-# Various StatsBase methods #
-#############################
-
-
-function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
-    print(io, "ModeResult with maximized lp of ")
-    Printf.@printf(io, "%.2f", m.lp)
-    println(io)
-    show(io, m.values)
-end
-
-function Base.show(io::IO, m::ModeResult)
-    show(io, m.values.array)
-end
-
-function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
-    # Get columns for coeftable.
-    terms = string.(StatsBase.coefnames(m))
-    estimates = m.values.array[:, 1]
-    stderrors = StatsBase.stderror(m)
-    zscore = estimates ./ stderrors
-    p = map(z -> StatsAPI.pvalue(Distributions.Normal(), z; tail=:both), zscore)
-
-    # Confidence interval (CI)
-    q = Statistics.quantile(Distributions.Normal(), (1 + level) / 2)
-    ci_low = estimates .- q .* stderrors
-    ci_high = estimates .+ q .* stderrors
-
-    level_ = 100*level
-    level_percentage = isinteger(level_) ? Int(level_) : level_
-
-    StatsBase.CoefTable(
-        [estimates, stderrors, zscore, p, ci_low, ci_high],
-        ["Coef.", "Std. Error", "z", "Pr(>|z|)", "Lower $(level_percentage)%", "Upper $(level_percentage)%"],
-        terms)
-end
-
-function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
-    # Calculate Hessian and information matrix.
-
-    # Convert the values to their unconstrained states to make sure the
-    # Hessian is computed with respect to the untransformed parameters.
-    linked = DynamicPPL.istrans(m.f.varinfo)
-    if linked
-        m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
-    end
-
-    # Calculate the Hessian, which is the information matrix because the negative of the log likelihood was optimized
-    varnames = StatsBase.coefnames(m)
-    info = hessian_function(m.f, m.values.array[:, 1])
-
-    # Link it back if we invlinked it.
-    if linked
-        m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
-    end
-
-    return NamedArrays.NamedArray(info, (varnames, varnames))
-end
-
-StatsBase.coef(m::ModeResult) = m.values
-StatsBase.coefnames(m::ModeResult) = names(m.values)[1]
-StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
-StatsBase.vcov(m::ModeResult) = inv(StatsBase.informationmatrix(m))
-StatsBase.loglikelihood(m::ModeResult) = m.lp
-
 ####################
 # Optim.jl methods #
 ####################
@@ -125,26 +42,41 @@ mle = optimize(model, MLE())
 mle = optimize(model, MLE(), NelderMead())
 ```
 """
-function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, options::Optim.Options=Optim.Options(); kwargs...)
-    ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
-    f = Turing.OptimLogDensity(model, ctx)
+function Optim.optimize(
+    model::DynamicPPL.Model, ::Optimisation.MLE, options::Optim.Options=Optim.Options();
+    kwargs...
+)
+    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
+    f = Optimisation.OptimLogDensity(model, ctx)
     init_vals = DynamicPPL.getparams(f)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
-function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
+function Optim.optimize(
+    model::DynamicPPL.Model,
+    ::Optimisation.MLE,
+    init_vals::AbstractArray,
+    options::Optim.Options=Optim.Options();
+    kwargs...
+)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
-function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
-    ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
-    f = Turing.OptimLogDensity(model, ctx)
+function Optim.optimize(
+    model::DynamicPPL.Model,
+    ::Optimisation.MLE,
+    optimizer::Optim.AbstractOptimizer,
+    options::Optim.Options=Optim.Options();
+    kwargs...
+)
+    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
+    f = Optimisation.OptimLogDensity(model, ctx)
     init_vals = DynamicPPL.getparams(f)
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
 function Optim.optimize(
     model::DynamicPPL.Model,
-    ::Turing.MLE,
+    ::Optimisation.MLE,
     init_vals::AbstractArray,
     optimizer::Optim.AbstractOptimizer,
     options::Optim.Options=Optim.Options();
@@ -154,8 +86,8 @@ function Optim.optimize(
 end
 
 function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
-    return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...)
+    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
+    return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
 end
 
 """
@@ -178,26 +110,41 @@ map_est = optimize(model, MAP())
 map_est = optimize(model, MAP(), NelderMead())
 ```
 """
-function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, options::Optim.Options=Optim.Options(); kwargs...)
-    ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
-    f = Turing.OptimLogDensity(model, ctx)
+function Optim.optimize(
+    model::DynamicPPL.Model, ::Optimisation.MAP, options::Optim.Options=Optim.Options();
+    kwargs...
+)
+    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
+    f = Optimisation.OptimLogDensity(model, ctx)
     init_vals = DynamicPPL.getparams(f)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
-function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
+function Optim.optimize(
+    model::DynamicPPL.Model,
+    ::Optimisation.MAP,
+    init_vals::AbstractArray,
+    options::Optim.Options=Optim.Options();
+    kwargs...
+)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
-function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
-    ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
-    f = Turing.OptimLogDensity(model, ctx)
+function Optim.optimize(
+    model::DynamicPPL.Model,
+    ::Optimisation.MAP,
+    optimizer::Optim.AbstractOptimizer,
+    options::Optim.Options=Optim.Options();
+    kwargs...
+)
+    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
+    f = Optimisation.OptimLogDensity(model, ctx)
     init_vals = DynamicPPL.getparams(f)
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
 function Optim.optimize(
     model::DynamicPPL.Model,
-    ::Turing.MAP,
+    ::Optimisation.MAP,
     init_vals::AbstractArray,
     optimizer::Optim.AbstractOptimizer,
     options::Optim.Options=Optim.Options();
@@ -207,8 +154,8 @@ function Optim.optimize(
 end
 
 function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
-    return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...)
+    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
+    return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
 end
 
 """
@@ -218,7 +165,7 @@ Estimate a mode, i.e., compute a MLE or MAP estimate.
 """
 function _optimize(
     model::DynamicPPL.Model,
-    f::Turing.OptimLogDensity,
+    f::Optimisation.OptimLogDensity,
    init_vals::AbstractArray=DynamicPPL.getparams(f),
     optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
     options::Optim.Options=Optim.Options(),
@@ -236,25 +183,19 @@ function _optimize(
 
     # Warn the user if the optimization did not converge.
     if !Optim.converged(M)
-        @warn "Optimization did not converge! You may need to correct your model or adjust the Optim parameters."
+        @warn """
+            Optimization did not converge! You may need to correct your model or adjust the
+            Optim parameters.
+        """
     end
 
-    # Get the VarInfo at the MLE/MAP point, and run the model to ensure
-    # correct dimensionality.
+    # Get the optimum in unconstrained space. `getparams` does the invlinking.
    f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
-    f = Accessors.@set f.varinfo = DynamicPPL.invlink(f.varinfo, model)
-    vals = DynamicPPL.getparams(f)
-    f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
-
-    # Make one transition to get the parameter names.
     vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
-
-    # Store the parameters and their names in an array.
     vmat = NamedArrays.NamedArray(vals, varnames)
-
-    return ModeResult(vmat, M, -M.minimum, f)
+    return Optimisation.ModeResult(vmat, M, -M.minimum, f)
 end
 
 end # module
````
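For comparison, the Optim.jl extension keeps its existing entry points; the calls below are lifted from the docstrings in this file, wrapped in a toy model of my own choosing. Loading Optim.jl is what activates the extension:

```julia
using Turing, Optim

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    y ~ Normal(m, sqrt(s²))
end

model = gdemo(1.5, 2.0)

mle = optimize(model, MLE())                    # defaults to Optim.LBFGS()
map_est = optimize(model, MAP(), NelderMead())  # any Optim.AbstractOptimizer works
```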

src/Turing.jl

Lines changed: 4 additions & 6 deletions
```diff
@@ -138,13 +138,11 @@ export @model, # modelling
 
     ordered, # Exports from Bijectors
 
-    constrained_space, # optimisation interface
+    maximum_a_posteriori,
+    maximum_likelihood,
+    # The MAP and MLE exports are only needed for the Optim.jl interface.
     MAP,
-    MLE,
-    get_parameter_bounds,
-    optim_objective,
-    optim_function,
-    optim_problem
+    MLE
 
 if !isdefined(Base, :get_extension)
     using Requires
```
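The StatsBase methods deleted from the extension (`coef`, `coeftable`, `informationmatrix`, `vcov`, ...) were moved, not dropped: `ModeResult` now lives in the `Turing.Optimisation` module, which is why the extension above returns `Optimisation.ModeResult`. Assuming those methods kept their previous behaviour after the move, a mode estimate can still be queried like any statistical model; a sketch:

```julia
using Turing, StatsBase

@model function gdemo(x, y)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x ~ Normal(m, sqrt(s²))
    y ~ Normal(m, sqrt(s²))
end

mle = maximum_likelihood(gdemo(1.5, 2.0))

StatsBase.coef(mle)       # point estimates as a NamedArray
StatsBase.coeftable(mle)  # estimates, standard errors, z-scores, CIs
StatsBase.vcov(mle)       # inverse of the observed information matrix
```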
