@@ -4,7 +4,6 @@ using ..Turing
using NamedArrays: NamedArrays
using DynamicPPL: DynamicPPL
using LogDensityProblems: LogDensityProblems
-using LogDensityProblemsAD: LogDensityProblemsAD
using Optimization: Optimization
using OptimizationOptimJL: OptimizationOptimJL
using Random: Random
@@ -95,22 +94,40 @@ function DynamicPPL.tilde_observe(
end

"""
-    OptimLogDensity{M<:DynamicPPL.Model,C<:Context,V<:DynamicPPL.VarInfo}
+    OptimLogDensity{M<:DynamicPPL.Model,V<:DynamicPPL.VarInfo,C<:OptimizationContext,AD<:ADTypes.AbstractADType}

A struct that stores the negative log density function of a `DynamicPPL` model.
+
+TODO(penelopeysm): It _doesn't_ really store the negative, does it? It's more like we
+overrode logdensity to give the negative logdensity.
"""
-const OptimLogDensity{M<:DynamicPPL.Model,C<:OptimizationContext,V<:DynamicPPL.VarInfo,AD} = Turing.LogDensityFunction{
-    M,V,C,AD
+struct OptimLogDensity{
+    M<:DynamicPPL.Model,
+    V<:DynamicPPL.VarInfo,
+    C<:OptimizationContext,
+    AD<:ADTypes.AbstractADType,
}
+    ldf::Turing.LogDensityFunction{M,V,C,AD}
+end

-"""
-    OptimLogDensity(model::DynamicPPL.Model, context::OptimizationContext)
+function OptimLogDensity(
+    model::DynamicPPL.Model,
+    vi::DynamicPPL.VarInfo,
+    ctx::OptimizationContext;
+    adtype::Union{Nothing,ADTypes.AbstractADType}=AutoForwardDiff(),
+)
+    return OptimLogDensity(Turing.LogDensityFunction(model, vi, ctx; adtype=adtype))
+end

-Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
-"""
-function OptimLogDensity(model::DynamicPPL.Model, context::OptimizationContext)
-    init = DynamicPPL.VarInfo(model)
-    return Turing.LogDensityFunction(model, init, context)
+# No varinfo
+function OptimLogDensity(
+    model::DynamicPPL.Model,
+    ctx::OptimizationContext;
+    adtype::Union{Nothing,ADTypes.AbstractADType}=AutoForwardDiff(),
+)
+    return OptimLogDensity(
+        Turing.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype)
+    )
end

"""
@@ -123,40 +140,30 @@ depends on the context of `f`.
Any second argument is ignored. The two-argument method only exists to match the interface
required by Optimization.jl.
"""
-function (f::OptimLogDensity)(z::AbstractVector)
-    varinfo = DynamicPPL.unflatten(f.varinfo, z)
-    return -DynamicPPL.getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, f.context)))
-end
-
+(f::OptimLogDensity)(z::AbstractVector) = -LogDensityProblems.logdensity(f.ldf, z)
(f::OptimLogDensity)(z, _) = f(z)

-# NOTE: This seems a bit weird IMO since this is the _negative_ log-likelihood.
-LogDensityProblems.logdensity(f::OptimLogDensity, z::AbstractVector) = f(z)
-
# NOTE: The format of this function is dictated by Optim. The first argument sets whether to
# compute the function value, the second whether to compute the gradient (and stores the
# gradient). The last one is the actual argument of the objective function.
function (f::OptimLogDensity)(F, G, z)
    if G !== nothing
-        # Calculate negative log joint and its gradient.
-        # TODO: Make OptimLogDensity already an LogDensityProblems.ADgradient? Allow to
-        # specify AD?
-        ℓ = LogDensityProblemsAD.ADgradient(f)
-        neglogp, ∇neglogp = LogDensityProblems.logdensity_and_gradient(ℓ, z)
+        # Calculate log joint and its gradient.
+        logp, ∇logp = LogDensityProblems.logdensity_and_gradient(f.ldf, z)

-        # Save the gradient to the pre-allocated array.
-        copyto!(G, ∇neglogp)
+        # Save the negative gradient to the pre-allocated array.
+        copyto!(G, -∇logp)

        # If F is something, the negative log joint is requested as well.
        # We have already computed it as a by-product above and hence return it directly.
        if F !== nothing
-            return neglogp
+            return -logp
        end
    end

    # Only negative log joint requested but no gradient.
    if F !== nothing
-        return LogDensityProblems.logdensity(f, z)
+        return -LogDensityProblems.logdensity(f.ldf, z)
    end

    return nothing
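
To make the (F, G, z) convention above concrete, a hedged sketch reusing a
hypothetical `ld::OptimLogDensity` for a one-parameter model:

    z = [0.5]        # the point at which to evaluate
    G = similar(z)   # pre-allocated gradient storage

    nlp = ld(true, G, z)        # fills G with the gradient of -logp, returns -logp
    ld(nothing, G, z)           # gradient only: fills G, returns nothing
    nlp = ld(true, nothing, z)  # value only: no gradient computation at all
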
@@ -232,9 +239,11 @@ function StatsBase.informationmatrix(

    # Convert the values to their unconstrained states to make sure the
    # Hessian is computed with respect to the untransformed parameters.
-    linked = DynamicPPL.istrans(m.f.varinfo)
+    linked = DynamicPPL.istrans(m.f.ldf.varinfo)
    if linked
-        m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
+        new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model)
+        new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context)
+        m = Accessors.@set m.f = new_f
    end

    # Calculate the Hessian, which is the information matrix because the negative of the log
@@ -244,7 +253,9 @@ function StatsBase.informationmatrix(

    # Link it back if we invlinked it.
    if linked
-        m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
+        new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model)
+        new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context)
+        m = Accessors.@set m.f = new_f
    end

    return NamedArrays.NamedArray(info, (varnames, varnames))
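
Because the varinfo now lives inside the immutable `ldf` field, both branches above
rebuild the whole `OptimLogDensity` rather than `@set`-ing a nested field. The
link/invlink round trip itself can be sketched in isolation (hypothetical `model`):

    vi = DynamicPPL.VarInfo(model)
    DynamicPPL.istrans(vi)                # false: values in constrained space
    vi = DynamicPPL.link!!(vi, model)     # move to unconstrained space
    DynamicPPL.istrans(vi)                # true
    vi = DynamicPPL.invlink!!(vi, model)  # and back
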
@@ -265,7 +276,7 @@ Return the values of all the variables with the symbol(s) `var_symbol` in the mo
argument should be either a `Symbol` or a vector of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
-    log_density = m.f
+    log_density = m.f.ldf
    # Get all the variable names in the model. This is the same as the list of keys in
    # m.values, but they are more convenient to filter when they are VarNames rather than
    # Symbols.
@@ -297,9 +308,9 @@ richer format of `ModeResult`. It also takes care of transforming them back to t
parameter space in case the optimization was done in a transformed space.
"""
function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
-    varinfo_new = DynamicPPL.unflatten(log_density.varinfo, solution.u)
+    varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
    # `getparams` performs invlinking if needed
-    vns_vals_iter = Turing.Inference.getparams(log_density.model, varinfo_new)
+    vns_vals_iter = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
    syms = map(Symbol ∘ first, vns_vals_iter)
    vals = map(last, vns_vals_iter)
    return ModeResult(
@@ -383,12 +394,15 @@ end
    OptimizationProblem(log_density::OptimLogDensity, adtype, constraints)

Create an `OptimizationProblem` for the objective function defined by `log_density`.
+
+Note that the `adtype` argument passed here overrides any `adtype` that the
+`OptimLogDensity` was constructed with.
"""
function Optimization.OptimizationProblem(log_density::OptimLogDensity, adtype, constraints)
    # Note that OptimLogDensity is a callable that evaluates the model with given
    # parameters. Hence we can use it in the objective function as below.
    f = Optimization.OptimizationFunction(log_density, adtype; cons=constraints.cons)
-    initial_params = log_density.varinfo[:]
+    initial_params = log_density.ldf.varinfo[:]
    prob = if !has_constraints(constraints)
        Optimization.OptimizationProblem(f, initial_params)
    else
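
A hedged sketch of the `adtype` override noted in the docstring, assuming an
`OptimLogDensity` `ld` built with the default `AutoForwardDiff()` and a
`constraints` object as produced elsewhere in this file:

    # The problem below differentiates with ReverseDiff regardless of the adtype
    # stored inside ld, because the adtype passed here takes precedence.
    prob = Optimization.OptimizationProblem(ld, AutoReverseDiff(), constraints)
    sol = Optimization.solve(prob, OptimizationOptimJL.LBFGS())
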
@@ -454,28 +468,34 @@ function estimate_mode(
    end

    # Create an OptimLogDensity object that can be used to evaluate the objective function,
-    # i.e. the negative log density. Set its VarInfo to the initial parameters.
-    log_density = let
-        inner_context = if estimator isa MAP
-            DynamicPPL.DefaultContext()
-        else
-            DynamicPPL.LikelihoodContext()
-        end
-        ctx = OptimizationContext(inner_context)
-        ld = OptimLogDensity(model, ctx)
-        Accessors.@set ld.varinfo = DynamicPPL.unflatten(ld.varinfo, initial_params)
+    # i.e. the negative log density.
+    inner_context = if estimator isa MAP
+        DynamicPPL.DefaultContext()
+    else
+        DynamicPPL.LikelihoodContext()
    end
+    ctx = OptimizationContext(inner_context)

+    # Set its VarInfo to the initial parameters.
+    # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated
+    # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the
+    # varinfo are completely ignored. The parameters only matter if you are calling evaluate!!
+    # directly on the fields of the LogDensityFunction.
+    vi = DynamicPPL.VarInfo(model)
+    vi = DynamicPPL.unflatten(vi, initial_params)
+
+    # Link the varinfo if needed.
    # TODO(mhauru) We currently couple together the questions of whether the user specified
    # bounds/constraints and whether we transform the objective function to an
    # unconstrained space. These should be separate concerns, but for that we need to
    # implement getting the bounds of the prior distributions.
    optimise_in_unconstrained_space = !has_constraints(constraints)
    if optimise_in_unconstrained_space
-        transformed_varinfo = DynamicPPL.link(log_density.varinfo, log_density.model)
-        log_density = Accessors.@set log_density.varinfo = transformed_varinfo
+        vi = DynamicPPL.link(vi, model)
    end

+    log_density = OptimLogDensity(model, vi, ctx)
+
    prob = Optimization.OptimizationProblem(log_density, adtype, constraints)
    solution = Optimization.solve(prob, solver; kwargs...)
    # TODO(mhauru) We return a ModeResult for compatibility with the older Optim.jl
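
For reference, a sketch of the public entry points that funnel into `estimate_mode`
(`maximum_a_posteriori` selects `DefaultContext`, i.e. the log joint, while
`maximum_likelihood` selects `LikelihoodContext`; the model is hypothetical):

    using Turing

    @model function gdemo(x)
        m ~ Normal(0, 1)
        for i in eachindex(x)
            x[i] ~ Normal(m, 1)
        end
    end

    model = gdemo([1.5, 2.0])
    map_result = maximum_a_posteriori(model)
    mle_result = maximum_likelihood(model)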