78
78
"""
79
79
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
80
80
81
- A struct that stores the log density function of a `DynamicPPL` model.
81
+ A type alias for `Turing.LogDensityFunction` (with an `OptimizationContext`) that represents the negative log density function of a `DynamicPPL` model.
82
82
"""
83
- struct OptimLogDensity{M<: Model ,C<: AbstractContext ,V<: VarInfo }
84
- " A `DynamicPPL.Model` constructed either with the `@model` macro or manually."
85
- model:: M
86
- " A `DynamicPPL.AbstractContext` used to evaluate the model. `LikelihoodContext` or `DefaultContext` are typical for MAP/MLE."
87
- context:: C
88
- " A `DynamicPPL.VarInfo` struct that will be used to update model parameters."
89
- vi:: V
90
- end
83
+ const OptimLogDensity{M<: Model ,C<: OptimizationContext ,V<: VarInfo } = Turing. LogDensityFunction{V,M,DynamicPPL. SampleFromPrior,C}
91
84
92
85
"""
93
- OptimLogDensity(model::Model, context::AbstractContext )
86
+ OptimLogDensity(model::Model, context::OptimizationContext )
94
87
95
88
Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
96
89
"""
97
- function OptimLogDensity (model:: Model , context:: AbstractContext )
90
+ function OptimLogDensity (model:: Model , context:: OptimizationContext )
98
91
init = VarInfo (model)
99
- return OptimLogDensity ( model, context, init )
92
+ return Turing . LogDensityFunction (init, model, DynamicPPL . SampleFromPrior (), context )
100
93
end
101
94
102
95
"""
103
96
(f::OptimLogDensity)(z)
104
97
105
- Evaluate the log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
98
+ Evaluate the negative log joint (with `DefaultContext`) or negative log likelihood (with `LikelihoodContext`)
106
99
at the array `z`.
107
100
"""
108
- function (f:: OptimLogDensity )(z)
109
- spl = DynamicPPL. SampleFromPrior ()
110
-
111
- varinfo = DynamicPPL. VarInfo (f. vi, spl, z)
112
- f. model (varinfo, spl, f. context)
113
- return - DynamicPPL. getlogp (varinfo)
101
+ function (f:: OptimLogDensity )(z:: AbstractVector )
102
+ sampler = f. sampler
103
+ varinfo = DynamicPPL. VarInfo (f. varinfo, sampler, z)
104
+ return - getlogp (last (DynamicPPL. evaluate!! (f. model, varinfo, sampler, f. context)))
114
105
end
115
106
116
- function (f:: OptimLogDensity )(F, G, H, z)
117
- # Throw an error if a second order method was used.
118
- if H != = nothing
119
- error (" Second order optimization is not yet supported." )
120
- end
121
-
122
- spl = DynamicPPL. SampleFromPrior ()
123
-
107
+ function (f:: OptimLogDensity )(F, G, z)
124
108
if G != = nothing
125
- # Calculate log joint and the gradient
126
- l, g = Turing. gradient_logp (
109
+ # Calculate negative log joint and its gradient.
110
+ sampler = f. sampler
111
+ neglogp, ∇neglogp = Turing. gradient_logp (
127
112
z,
128
- DynamicPPL. VarInfo (f. vi, spl , z),
113
+ DynamicPPL. VarInfo (f. varinfo, sampler , z),
129
114
f. model,
130
- spl ,
131
- f. context
115
+ sampler ,
116
+ f. context,
132
117
)
133
118
134
- # Use the negative gradient because we are minimizing .
135
- G[:] = - g
119
+ # Save the gradient to the pre-allocated array .
120
+ copyto! (G, ∇neglogp)
136
121
137
- # If F is something, return that since we already have the
138
- # log joint .
122
+ # If F is something, the negative log joint is requested as well.
123
+ # We have already computed it as a by-product above and hence return it directly .
139
124
if F != = nothing
140
- F = - l
141
- return F
125
+ return neglogp
142
126
end
143
127
end
144
128
145
- # No gradient necessary, just return the log joint .
129
+ # Only negative log joint requested but no gradient .
146
130
if F != = nothing
147
- F = f (z)
148
- return F
131
+ return f (z)
149
132
end
150
133
151
134
return nothing
@@ -158,16 +141,16 @@ end
158
141
# ################################################
159
142
160
143
function transform! (f:: OptimLogDensity )
161
- spl = DynamicPPL . SampleFromPrior ()
144
+ spl = f . sampler
162
145
163
146
# # Check link status of vi in OptimLogDensity
164
- linked = DynamicPPL. islinked (f. vi , spl)
147
+ linked = DynamicPPL. islinked (f. varinfo , spl)
165
148
166
149
# # transform into constrained or unconstrained space depending on current state of vi
167
150
if ! linked
168
- DynamicPPL. link! (f. vi , spl)
151
+ DynamicPPL. link! (f. varinfo , spl)
169
152
else
170
- DynamicPPL. invlink! (f. vi , spl)
153
+ DynamicPPL. invlink! (f. varinfo , spl)
171
154
end
172
155
173
156
return nothing
@@ -249,8 +232,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MAP, ::constrained_space{fa
249
232
obj = OptimLogDensity (model, ctx)
250
233
251
234
transform! (obj)
252
- init = Init (obj. vi , constrained_space {false} ())
253
- t = ParameterTransform (obj. vi , constrained_space {true} ())
235
+ init = Init (obj. varinfo , constrained_space {false} ())
236
+ t = ParameterTransform (obj. varinfo , constrained_space {true} ())
254
237
255
238
return (obj= obj, init = init, transform= t)
256
239
end
@@ -259,8 +242,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MAP, ::constrained_space{tr
259
242
ctx = OptimizationContext (DynamicPPL. DefaultContext ())
260
243
obj = OptimLogDensity (model, ctx)
261
244
262
- init = Init (obj. vi , constrained_space {true} ())
263
- t = ParameterTransform (obj. vi , constrained_space {true} ())
245
+ init = Init (obj. varinfo , constrained_space {true} ())
246
+ t = ParameterTransform (obj. varinfo , constrained_space {true} ())
264
247
265
248
return (obj= obj, init = init, transform= t)
266
249
end
@@ -270,8 +253,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MLE, ::constrained_space{f
270
253
obj = OptimLogDensity (model, ctx)
271
254
272
255
transform! (obj)
273
- init = Init (obj. vi , constrained_space {false} ())
274
- t = ParameterTransform (obj. vi , constrained_space {true} ())
256
+ init = Init (obj. varinfo , constrained_space {false} ())
257
+ t = ParameterTransform (obj. varinfo , constrained_space {true} ())
275
258
276
259
return (obj= obj, init = init, transform= t)
277
260
end
@@ -280,8 +263,8 @@ function _optim_objective(model::DynamicPPL.Model, ::MLE, ::constrained_space{tr
280
263
ctx = OptimizationContext (DynamicPPL. LikelihoodContext ())
281
264
obj = OptimLogDensity (model, ctx)
282
265
283
- init = Init (obj. vi , constrained_space {true} ())
284
- t = ParameterTransform (obj. vi , constrained_space {true} ())
266
+ init = Init (obj. varinfo , constrained_space {true} ())
267
+ t = ParameterTransform (obj. varinfo , constrained_space {true} ())
285
268
286
269
return (obj= obj, init = init, transform= t)
287
270
end
@@ -309,8 +292,7 @@ function optim_function(
309
292
else
310
293
OptimizationFunction (
311
294
l;
312
- grad = (G,x,p) -> obj (nothing , G, nothing , x),
313
- hess = (H,x,p) -> obj (nothing , nothing , H, x),
295
+ grad = (G,x,p) -> obj (nothing , G, x),
314
296
)
315
297
end
316
298
0 commit comments