
Commit fe46018

update Optimisation code to not use LogDensityProblemsAD

1 parent b21bc4b commit fe46018

File tree: 4 files changed, +88 -69 lines
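Taken together, the diffs below change Turing's mode-estimation internals so that `OptimLogDensity` wraps a `Turing.LogDensityFunction` (reached through an `ldf` field) and obtains gradients from that object directly, instead of wrapping the whole objective in `LogDensityProblemsAD.ADgradient`. For orientation, here is a hedged usage sketch of the user-facing API these files sit behind; the model is invented, and the entry points (`Optim.optimize(model, MLE())`, `maximum_likelihood`, `maximum_a_posteriori`) are taken from Turing's documented mode-estimation interface rather than from this commit, so exact signatures may vary between versions.

```julia
using Turing
using Optim

# Toy model, for illustration only.
@model function gdemo(x)
    σ² ~ InverseGamma(2, 3)
    μ ~ Normal(0, sqrt(σ²))
    for i in eachindex(x)
        x[i] ~ Normal(μ, sqrt(σ²))
    end
end

model = gdemo([1.5, 2.0, 2.3])

# Optim.jl extension path (ext/TuringOptimExt.jl below):
mle_optim = Optim.optimize(model, MLE())
map_optim = Optim.optimize(model, MAP(), Optim.LBFGS())

# Optimization.jl path (src/optimisation/Optimisation.jl below):
mle = maximum_likelihood(model)
map_est = maximum_a_posteriori(model)
```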

ext/TuringOptimExt.jl

Lines changed: 21 additions & 20 deletions
@@ -1,14 +1,8 @@
 module TuringOptimExt
 
-if isdefined(Base, :get_extension)
-    using Turing: Turing
-    import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
-    using Optim: Optim
-else
-    import ..Turing
-    import ..Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
-    import ..Optim
-end
+using Turing: Turing
+import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
+using Optim: Optim
 
 ####################
 # Optim.jl methods #
@@ -42,7 +36,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -65,7 +59,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
 function Optim.optimize(
@@ -112,7 +106,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -135,7 +129,7 @@ function Optim.optimize(
 )
     ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
     f = Optimisation.OptimLogDensity(model, ctx)
-    init_vals = DynamicPPL.getparams(f)
+    init_vals = DynamicPPL.getparams(f.ldf)
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
 function Optim.optimize(
@@ -162,17 +156,20 @@ Estimate a mode, i.e., compute a MLE or MAP estimate.
 function _optimize(
     model::DynamicPPL.Model,
     f::Optimisation.OptimLogDensity,
-    init_vals::AbstractArray=DynamicPPL.getparams(f),
+    init_vals::AbstractArray=DynamicPPL.getparams(f.ldf),
     optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
     options::Optim.Options=Optim.Options(),
     args...;
     kwargs...,
 )
     # Convert the initial values, since it is assumed that users provide them
     # in the constrained space.
-    f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
-    f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
-    init_vals = DynamicPPL.getparams(f)
+    # TODO(penelopeysm): As with in src/optimisation/Optimisation.jl, unclear
+    # whether initialisation is really necessary at all
+    vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
+    vi = DynamicPPL.link(vi, f.ldf.model)
+    f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
+    init_vals = DynamicPPL.getparams(f.ldf)
 
     # Optimize!
     M = Optim.optimize(Optim.only_fg!(f), init_vals, optimizer, options, args...; kwargs...)
@@ -186,12 +183,16 @@ function _optimize(
     end
 
     # Get the optimum in unconstrained space. `getparams` does the invlinking.
-    f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
-    vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
+    vi = f.ldf.varinfo
+    vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
+    logdensity_optimum = Optimisation.OptimLogDensity(
+        f.ldf.model, vi_optimum, f.ldf.context
+    )
+    vns_vals_iter = Turing.Inference.getparams(model, vi_optimum)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     vmat = NamedArrays.NamedArray(vals, varnames)
-    return Optimisation.ModeResult(vmat, M, -M.minimum, f)
+    return Optimisation.ModeResult(vmat, M, -M.minimum, logdensity_optimum)
 end
 
 end # module
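`_optimize` above hands the objective to Optim via `Optim.only_fg!(f)`, which calls the three-argument method `(f::OptimLogDensity)(F, G, z)` defined in src/optimisation/Optimisation.jl below. A minimal sketch of that F/G calling convention with a hand-written stand-in objective (not Turing code):

```julia
using Optim

# Stand-in objective: 0.5 * ‖z‖², playing the role of the negative log density.
function fg!(F, G, z)
    if G !== nothing
        copyto!(G, z)  # gradient of 0.5 * sum(abs2, z) is z
    end
    if F !== nothing
        return 0.5 * sum(abs2, z)
    end
    return nothing
end

result = Optim.optimize(Optim.only_fg!(fg!), [1.0, -2.0], Optim.LBFGS(), Optim.Options())
Optim.minimizer(result)  # ≈ [0.0, 0.0]
```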

src/optimisation/Optimisation.jl

Lines changed: 67 additions & 47 deletions
@@ -4,7 +4,6 @@ using ..Turing
 using NamedArrays: NamedArrays
 using DynamicPPL: DynamicPPL
 using LogDensityProblems: LogDensityProblems
-using LogDensityProblemsAD: LogDensityProblemsAD
 using Optimization: Optimization
 using OptimizationOptimJL: OptimizationOptimJL
 using Random: Random
@@ -95,22 +94,40 @@ function DynamicPPL.tilde_observe(
 end
 
 """
-    OptimLogDensity{M<:DynamicPPL.Model,C<:Context,V<:DynamicPPL.VarInfo}
+    OptimLogDensity{M<:DynamicPPL.Model,V<:DynamicPPL.VarInfo,C<:OptimizationContext,AD<:ADTypes.AbstractADType}
 
 A struct that stores the negative log density function of a `DynamicPPL` model.
+
+TODO(penelopeysm): It _doesn't_ really store the negative, does it? It's more like we
+overrode logdensity to give the negative logdensity.
 """
-const OptimLogDensity{M<:DynamicPPL.Model,C<:OptimizationContext,V<:DynamicPPL.VarInfo,AD} = Turing.LogDensityFunction{
-    M,V,C,AD
+struct OptimLogDensity{
+    M<:DynamicPPL.Model,
+    V<:DynamicPPL.VarInfo,
+    C<:OptimizationContext,
+    AD<:ADTypes.AbstractADType,
 }
+    ldf::Turing.LogDensityFunction{M,V,C,AD}
+end
 
-"""
-    OptimLogDensity(model::DynamicPPL.Model, context::OptimizationContext)
+function OptimLogDensity(
+    model::DynamicPPL.Model,
+    vi::DynamicPPL.VarInfo,
+    ctx::OptimizationContext;
+    adtype::Union{Nothing,ADTypes.AbstractADType}=AutoForwardDiff(),
+)
+    return OptimLogDensity(Turing.LogDensityFunction(model, vi, ctx; adtype=adtype))
+end
 
-Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
-"""
-function OptimLogDensity(model::DynamicPPL.Model, context::OptimizationContext)
-    init = DynamicPPL.VarInfo(model)
-    return Turing.LogDensityFunction(model, init, context)
+# No varinfo
+function OptimLogDensity(
+    model::DynamicPPL.Model,
+    ctx::OptimizationContext;
+    adtype::Union{Nothing,ADTypes.AbstractADType}=AutoForwardDiff(),
+)
+    return OptimLogDensity(
+        Turing.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype)
+    )
 end
 
 """
@@ -123,40 +140,30 @@ depends on the context of `f`.
 
 Any second argument is ignored. The two-argument method only exists to match the interface
 required by Optimization.jl.
 """
-function (f::OptimLogDensity)(z::AbstractVector)
-    varinfo = DynamicPPL.unflatten(f.varinfo, z)
-    return -DynamicPPL.getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, f.context)))
-end
-
+(f::OptimLogDensity)(z::AbstractVector) = -LogDensityProblems.logdensity(f.ldf, z)
 (f::OptimLogDensity)(z, _) = f(z)
 
-# NOTE: This seems a bit weird IMO since this is the _negative_ log-likelihood.
-LogDensityProblems.logdensity(f::OptimLogDensity, z::AbstractVector) = f(z)
-
 # NOTE: The format of this function is dictated by Optim. The first argument sets whether to
 # compute the function value, the second whether to compute the gradient (and stores the
 # gradient). The last one is the actual argument of the objective function.
 function (f::OptimLogDensity)(F, G, z)
     if G !== nothing
-        # Calculate negative log joint and its gradient.
-        # TODO: Make OptimLogDensity already an LogDensityProblems.ADgradient? Allow to
-        # specify AD?
-        ℓ = LogDensityProblemsAD.ADgradient(f)
-        neglogp, ∇neglogp = LogDensityProblems.logdensity_and_gradient(ℓ, z)
+        # Calculate log joint and its gradient.
+        logp, ∇logp = LogDensityProblems.logdensity_and_gradient(f.ldf, z)
 
-        # Save the gradient to the pre-allocated array.
-        copyto!(G, ∇neglogp)
+        # Save the negative gradient to the pre-allocated array.
+        copyto!(G, -∇logp)
 
         # If F is something, the negative log joint is requested as well.
         # We have already computed it as a by-product above and hence return it directly.
        if F !== nothing
-            return neglogp
+            return -logp
        end
    end
 
     # Only negative log joint requested but no gradient.
     if F !== nothing
-        return LogDensityProblems.logdensity(f, z)
+        return -LogDensityProblems.logdensity(f.ldf, z)
     end
 
     return nothing
@@ -232,9 +239,11 @@ function StatsBase.informationmatrix(
 
     # Convert the values to their unconstrained states to make sure the
     # Hessian is computed with respect to the untransformed parameters.
-    linked = DynamicPPL.istrans(m.f.varinfo)
+    linked = DynamicPPL.istrans(m.f.ldf.varinfo)
     if linked
-        m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
+        new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model)
+        new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context)
+        m = Accessors.@set m.f = new_f
     end
 
     # Calculate the Hessian, which is the information matrix because the negative of the log
@@ -244,7 +253,9 @@ function StatsBase.informationmatrix(
 
     # Link it back if we invlinked it.
     if linked
-        m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
+        new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model)
+        new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context)
+        m = Accessors.@set m.f = new_f
     end
 
     return NamedArrays.NamedArray(info, (varnames, varnames))
@@ -265,7 +276,7 @@ Return the values of all the variables with the symbol(s) `var_symbol` in the mo
 argument should be either a `Symbol` or a vector of `Symbol`s.
 """
 function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
-    log_density = m.f
+    log_density = m.f.ldf
     # Get all the variable names in the model. This is the same as the list of keys in
     # m.values, but they are more convenient to filter when they are VarNames rather than
     # Symbols.
@@ -297,9 +308,9 @@ richer format of `ModeResult`. It also takes care of transforming them back to t
 parameter space in case the optimization was done in a transformed space.
 """
 function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
-    varinfo_new = DynamicPPL.unflatten(log_density.varinfo, solution.u)
+    varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
     # `getparams` performs invlinking if needed
-    vns_vals_iter = Turing.Inference.getparams(log_density.model, varinfo_new)
+    vns_vals_iter = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
     syms = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     return ModeResult(
@@ -383,12 +394,15 @@ end
     OptimizationProblem(log_density::OptimLogDensity, adtype, constraints)
 
 Create an `OptimizationProblem` for the objective function defined by `log_density`.
+
+Note that the adtype parameter here overrides any adtype parameter the
+OptimLogDensity was constructed with.
 """
 function Optimization.OptimizationProblem(log_density::OptimLogDensity, adtype, constraints)
     # Note that OptimLogDensity is a callable that evaluates the model with given
     # parameters. Hence we can use it in the objective function as below.
     f = Optimization.OptimizationFunction(log_density, adtype; cons=constraints.cons)
-    initial_params = log_density.varinfo[:]
+    initial_params = log_density.ldf.varinfo[:]
     prob = if !has_constraints(constraints)
         Optimization.OptimizationProblem(f, initial_params)
     else
@@ -454,28 +468,34 @@ function estimate_mode(
     end
 
     # Create an OptimLogDensity object that can be used to evaluate the objective function,
-    # i.e. the negative log density. Set its VarInfo to the initial parameters.
-    log_density = let
-        inner_context = if estimator isa MAP
-            DynamicPPL.DefaultContext()
-        else
-            DynamicPPL.LikelihoodContext()
-        end
-        ctx = OptimizationContext(inner_context)
-        ld = OptimLogDensity(model, ctx)
-        Accessors.@set ld.varinfo = DynamicPPL.unflatten(ld.varinfo, initial_params)
+    # i.e. the negative log density.
+    inner_context = if estimator isa MAP
+        DynamicPPL.DefaultContext()
+    else
+        DynamicPPL.LikelihoodContext()
     end
+    ctx = OptimizationContext(inner_context)
 
+    # Set its VarInfo to the initial parameters.
+    # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated
+    # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the
+    # varinfo are completely ignored. The parameters only matter if you are calling evaluate!!
+    # directly on the fields of the LogDensityFunction
+    vi = DynamicPPL.VarInfo(model)
+    vi = DynamicPPL.unflatten(vi, initial_params)
+
+    # Link the varinfo if needed.
    # TODO(mhauru) We currently couple together the questions of whether the user specified
     # bounds/constraints and whether we transform the objective function to an
     # unconstrained space. These should be separate concerns, but for that we need to
     # implement getting the bounds of the prior distributions.
     optimise_in_unconstrained_space = !has_constraints(constraints)
     if optimise_in_unconstrained_space
-        transformed_varinfo = DynamicPPL.link(log_density.varinfo, log_density.model)
-        log_density = Accessors.@set log_density.varinfo = transformed_varinfo
+        vi = DynamicPPL.link(vi, model)
    end
 
+    log_density = OptimLogDensity(model, vi, ctx)
+
     prob = Optimization.OptimizationProblem(log_density, adtype, constraints)
     solution = Optimization.solve(prob, solver; kwargs...)
     # TODO(mhauru) We return a ModeResult for compatibility with the older Optim.jl
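The heart of the change above is that `OptimLogDensity` is now a thin struct around a `Turing.LogDensityFunction` and simply negates whatever `LogDensityProblems.logdensity` and `logdensity_and_gradient` return for it. A self-contained sketch of that wrap-and-negate pattern with stand-in types (`ToyLogDensity` and `NegatedObjective` are illustrative names, not Turing's):

```julia
using LogDensityProblems
import ForwardDiff

# Stand-in for Turing.LogDensityFunction: an object that knows its own log density
# and gradient (here a standard normal, differentiated with ForwardDiff).
struct ToyLogDensity end
LogDensityProblems.dimension(::ToyLogDensity) = 2
LogDensityProblems.logdensity(::ToyLogDensity, z) = -0.5 * sum(abs2, z)
function LogDensityProblems.logdensity_and_gradient(ld::ToyLogDensity, z)
    logp = LogDensityProblems.logdensity(ld, z)
    return logp, ForwardDiff.gradient(x -> LogDensityProblems.logdensity(ld, x), z)
end

# Stand-in for the new OptimLogDensity: wrap the log-density object and negate it.
struct NegatedObjective{L}
    ldf::L
end
(f::NegatedObjective)(z::AbstractVector) = -LogDensityProblems.logdensity(f.ldf, z)
function (f::NegatedObjective)(F, G, z)
    if G !== nothing
        logp, ∇logp = LogDensityProblems.logdensity_and_gradient(f.ldf, z)
        copyto!(G, -∇logp)
        F !== nothing && return -logp
    end
    F !== nothing && return f(z)
    return nothing
end

f = NegatedObjective(ToyLogDensity())
f([1.0, 2.0])                            # 2.5
G = zeros(2); f(0.0, G, [1.0, 2.0]); G   # [1.0, 2.0]
```

Because the wrapped object already carries its own gradient machinery (in Turing, the `adtype` stored in the `LogDensityFunction`), no separate `ADgradient` wrapper is needed at the call site, which is why the `LogDensityProblemsAD` dependency can be dropped.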

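The docstring added to `Optimization.OptimizationProblem` above notes that its `adtype` argument overrides whatever the `OptimLogDensity` was constructed with. A hedged sketch of how that surfaces to users, reusing the `gdemo` model from the sketch near the top and assuming the `adtype` keyword exposed by Turing's `maximum_likelihood`:

```julia
using ADTypes: AutoReverseDiff
import ReverseDiff

# The backend requested here ends up as the `adtype` argument of
# Optimization.OptimizationProblem(log_density, adtype, constraints).
mle_rd = maximum_likelihood(model; adtype=AutoReverseDiff())
```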
test/ext/OptimInterface.jl

Lines changed: 0 additions & 1 deletion
@@ -143,7 +143,6 @@ using Turing
         DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal,
         DynamicPPL.TestUtils.demo_dot_assume_observe_submodel,
         DynamicPPL.TestUtils.demo_dot_assume_observe_matrix_index,
-        DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix,
         DynamicPPL.TestUtils.demo_assume_submodel_observe_index_literal,
         DynamicPPL.TestUtils.demo_dot_assume_observe_index,
         DynamicPPL.TestUtils.demo_dot_assume_observe_index_literal,

test/optimisation/Optimisation.jl

Lines changed: 0 additions & 1 deletion
@@ -545,7 +545,6 @@ using Turing
         DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal,
         DynamicPPL.TestUtils.demo_dot_assume_observe_submodel,
         DynamicPPL.TestUtils.demo_dot_assume_observe_matrix_index,
-        DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix,
         DynamicPPL.TestUtils.demo_assume_submodel_observe_index_literal,
         DynamicPPL.TestUtils.demo_dot_assume_observe_index,
         DynamicPPL.TestUtils.demo_dot_assume_observe_index_literal,
