@@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
18
18
"""
19
19
LogDensityFunction(
20
20
model::Model,
21
- varinfo::AbstractVarInfo=VarInfo(model);
21
+ getlogdensity::Function=getlogjoint,
22
+ varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
22
23
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
23
24
)
24
25
@@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to:
28
29
- and if `adtype` is provided, calculate the gradient of the log density at
29
30
that point.
30
31
31
- At its most basic level, a LogDensityFunction wraps the model together with the
32
- type of varinfo to be used. These must be known in order to calculate the log
33
- density (using [`DynamicPPL.evaluate!!`](@ref)).
32
+ At its most basic level, a LogDensityFunction wraps the model together with a
33
+ function that specifies how to extract the log density, and the type of
34
+ VarInfo to be used. These must be known in order to calculate the log density
35
+ (using [`DynamicPPL.evaluate!!`](@ref)).
34
36
35
37
If the `adtype` keyword argument is provided, then this struct will also store
36
38
the adtype along with other information for efficient calculation of the
@@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
72
74
1
73
75
74
76
julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
75
- f = LogDensityFunction(model, SimpleVarInfo(model));
77
+ f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));
76
78
77
79
julia> LogDensityProblems.logdensity(f, [0.0])
78
80
-2.3378770664093453
79
81
80
- julia> # LogDensityFunction respects the accumulators in VarInfo :
81
- f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) );
82
+ julia> # One can also specify evaluating e.g. the log prior only :
83
+ f_prior = LogDensityFunction(model, getlogprior );
82
84
83
85
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
84
86
true
@@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
93
95
```
94
96
"""
95
97
struct LogDensityFunction{
96
- M<: Model ,V<: AbstractVarInfo ,AD<: Union{Nothing,ADTypes.AbstractADType}
98
+ M<: Model ,F <: Function , V<: AbstractVarInfo ,AD<: Union{Nothing,ADTypes.AbstractADType}
97
99
} <: AbstractModel
98
100
" model used for evaluation"
99
101
model:: M
100
- " varinfo used for evaluation"
102
+ " function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
103
+ getlogdensity:: F
104
+ " varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
101
105
varinfo:: V
102
106
" AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
103
107
adtype:: AD
@@ -106,7 +110,8 @@ struct LogDensityFunction{
106
110
107
111
function LogDensityFunction (
108
112
model:: Model ,
109
- varinfo:: AbstractVarInfo = VarInfo (model);
113
+ getlogdensity:: Function = getlogjoint,
114
+ varinfo:: AbstractVarInfo = ldf_default_varinfo (model, getlogdensity);
110
115
adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
111
116
)
112
117
if adtype === nothing
@@ -120,15 +125,22 @@ struct LogDensityFunction{
120
125
# Get a set of dummy params to use for prep
121
126
x = map (identity, varinfo[:])
122
127
if use_closure (adtype)
123
- prep = DI. prepare_gradient (LogDensityAt (model, varinfo), adtype, x)
128
+ prep = DI. prepare_gradient (
129
+ LogDensityAt (model, getlogdensity, varinfo), adtype, x
130
+ )
124
131
else
125
132
prep = DI. prepare_gradient (
126
- logdensity_at, adtype, x, DI. Constant (model), DI. Constant (varinfo)
133
+ logdensity_at,
134
+ adtype,
135
+ x,
136
+ DI. Constant (model),
137
+ DI. Constant (getlogdensity),
138
+ DI. Constant (varinfo),
127
139
)
128
140
end
129
141
end
130
- return new {typeof(model),typeof(varinfo),typeof(adtype)} (
131
- model, varinfo, adtype, prep
142
+ return new {typeof(model),typeof(getlogdensity),typeof( varinfo),typeof(adtype)} (
143
+ model, getlogdensity, varinfo, adtype, prep
132
144
)
133
145
end
134
146
end
@@ -149,83 +161,112 @@ function LogDensityFunction(
149
161
return if adtype === f. adtype
150
162
f # Avoid recomputing prep if not needed
151
163
else
152
- LogDensityFunction (f. model, f. varinfo; adtype= adtype)
164
+ LogDensityFunction (f. model, f. getlogdensity, f . varinfo; adtype= adtype)
153
165
end
154
166
end
155
167
168
+ """
169
+ ldf_default_varinfo(model::Model, getlogdensity::Function)
170
+
171
+ Create the default AbstractVarInfo that should be used for evaluating the log density.
172
+
173
+ Only the accumulators necessary for `getlogdensity` will be used.
174
+ """
175
+ function ldf_default_varinfo (:: Model , getlogdensity:: Function )
176
+ msg = """
177
+ LogDensityFunction does not know what sort of VarInfo should be used when \
178
+ `getlogdensity` is $getlogdensity . Please specify a VarInfo explicitly.
179
+ """
180
+ return error (msg)
181
+ end
182
+
183
+ ldf_default_varinfo (model:: Model , :: typeof (getlogjoint)) = VarInfo (model)
184
+
185
+ function ldf_default_varinfo (model:: Model , :: typeof (getlogprior))
186
+ return setaccs!! (VarInfo (model), (LogPriorAccumulator (),))
187
+ end
188
+
189
+ function ldf_default_varinfo (model:: Model , :: typeof (getloglikelihood))
190
+ return setaccs!! (VarInfo (model), (LogLikelihoodAccumulator (),))
191
+ end
192
+
156
193
"""
157
194
logdensity_at(
158
195
x::AbstractVector,
159
196
model::Model,
197
+ getlogdensity::Function,
160
198
varinfo::AbstractVarInfo,
161
199
)
162
200
163
- Evaluate the log density of the given `model` at the given parameter values `x`,
164
- using the given `varinfo`. Note that the `varinfo` argument is provided only
165
- for its structure, in the sense that the parameters from the vector `x` are
166
- inserted into it, and its own parameters are discarded. It does, however,
167
- determine whether the log prior, likelihood, or joint is returned, based on
168
- which accumulators are set in it.
201
+ Evaluate the log density of the given `model` at the given parameter values
202
+ `x`, using the given `varinfo`. Note that the `varinfo` argument is provided
203
+ only for its structure, in the sense that the parameters from the vector `x`
204
+ are inserted into it, and its own parameters are discarded. `getlogdensity` is
205
+ the function that extracts the log density from the evaluated varinfo.
169
206
"""
170
- function logdensity_at (x:: AbstractVector , model:: Model , varinfo:: AbstractVarInfo )
207
+ function logdensity_at (
208
+ x:: AbstractVector , model:: Model , getlogdensity:: Function , varinfo:: AbstractVarInfo
209
+ )
171
210
varinfo_new = unflatten (varinfo, x)
172
211
varinfo_eval = last (evaluate!! (model, varinfo_new))
173
- has_prior = hasacc (varinfo_eval, Val (:LogPrior ))
174
- has_likelihood = hasacc (varinfo_eval, Val (:LogLikelihood ))
175
- if has_prior && has_likelihood
176
- return getlogjoint (varinfo_eval)
177
- elseif has_prior
178
- return getlogprior (varinfo_eval)
179
- elseif has_likelihood
180
- return getloglikelihood (varinfo_eval)
181
- else
182
- error (" LogDensityFunction: varinfo tracks neither log prior nor log likelihood" )
183
- end
212
+ return getlogdensity (varinfo_eval)
184
213
end
185
214
186
215
"""
187
- LogDensityAt{M<:Model,V<:AbstractVarInfo}(
216
+ LogDensityAt{M<:Model,F<:Function, V<:AbstractVarInfo}(
188
217
model::M,
218
+ getlogdensity::F,
189
219
varinfo::V
190
220
)
191
221
192
222
A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
193
- varinfo)`.
223
+ getlogdensity, varinfo)`.
194
224
"""
195
- struct LogDensityAt{M<: Model ,V<: AbstractVarInfo }
225
+ struct LogDensityAt{M<: Model ,F <: Function , V<: AbstractVarInfo }
196
226
model:: M
227
+ getlogdensity:: F
197
228
varinfo:: V
198
229
end
199
- (ld:: LogDensityAt )(x:: AbstractVector ) = logdensity_at (x, ld. model, ld. varinfo)
230
+ function (ld:: LogDensityAt )(x:: AbstractVector )
231
+ return logdensity_at (x, ld. model, ld. getlogdensity, ld. varinfo)
232
+ end
200
233
201
234
# ## LogDensityProblems interface
202
235
203
236
function LogDensityProblems. capabilities (
204
- :: Type{<:LogDensityFunction{M,V,Nothing}}
205
- ) where {M,V}
237
+ :: Type{<:LogDensityFunction{M,F, V,Nothing}}
238
+ ) where {M,F, V}
206
239
return LogDensityProblems. LogDensityOrder {0} ()
207
240
end
208
241
function LogDensityProblems. capabilities (
209
- :: Type{<:LogDensityFunction{M,V,AD}}
210
- ) where {M,V,AD<: ADTypes.AbstractADType }
242
+ :: Type{<:LogDensityFunction{M,F, V,AD}}
243
+ ) where {M,F, V,AD<: ADTypes.AbstractADType }
211
244
return LogDensityProblems. LogDensityOrder {1} ()
212
245
end
213
246
function LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
214
- return logdensity_at (x, f. model, f. varinfo)
247
+ return logdensity_at (x, f. model, f. getlogdensity, f . varinfo)
215
248
end
216
249
function LogDensityProblems. logdensity_and_gradient (
217
- f:: LogDensityFunction{M,V,AD} , x:: AbstractVector
218
- ) where {M,V,AD<: ADTypes.AbstractADType }
250
+ f:: LogDensityFunction{M,F, V,AD} , x:: AbstractVector
251
+ ) where {M,F, V,AD<: ADTypes.AbstractADType }
219
252
f. prep === nothing &&
220
253
error (" Gradient preparation not available; this should not happen" )
221
254
x = map (identity, x) # Concretise type
222
255
# Make branching statically inferrable, i.e. type-stable (even if the two
223
256
# branches happen to return different types)
224
257
return if use_closure (f. adtype)
225
- DI. value_and_gradient (LogDensityAt (f. model, f. varinfo), f. prep, f. adtype, x)
258
+ DI. value_and_gradient (
259
+ LogDensityAt (f. model, f. getlogdensity, f. varinfo), f. prep, f. adtype, x
260
+ )
226
261
else
227
262
DI. value_and_gradient (
228
- logdensity_at, f. prep, f. adtype, x, DI. Constant (f. model), DI. Constant (f. varinfo)
263
+ logdensity_at,
264
+ f. prep,
265
+ f. adtype,
266
+ x,
267
+ DI. Constant (f. model),
268
+ DI. Constant (f. getlogdensity),
269
+ DI. Constant (f. varinfo),
229
270
)
230
271
end
231
272
end
@@ -264,9 +305,9 @@ There are two ways of dealing with this:
264
305
265
306
1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)
266
307
267
- 2. Use a constant context. This lets us pass a two-argument function to
268
- DifferentiationInterface, as long as we also give it the 'inactive argument'
269
- (i.e. the model) wrapped in `DI.Constant`.
308
+ 2. Use a constant DI.Context. This lets us pass a two-argument function to DI,
309
+ as long as we also give it the 'inactive argument' (i.e. the model) wrapped
310
+ in `DI.Constant`.
270
311
271
312
The relative performance of the two approaches, however, depends on the AD
272
313
backend used. Some benchmarks are provided here:
@@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292
333
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
293
334
"""
294
335
function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
295
- return LogDensityFunction (model, f. varinfo; adtype= f. adtype)
336
+ return LogDensityFunction (model, f. getlogdensity, f . varinfo; adtype= f. adtype)
296
337
end
297
338
298
339
"""
0 commit comments