@@ -17,7 +17,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
 """
     LogDensityFunction(
         model::Model,
-        varinfo::AbstractVarInfo=VarInfo(model),
+        getlogdensity::Function=getlogjoint,
+        varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
         context::AbstractContext=DefaultContext();
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
     )
@@ -28,10 +29,10 @@ A struct which contains a model, along with all the information necessary to:
 - and if `adtype` is provided, calculate the gradient of the log density at
   that point.
 
-At its most basic level, a LogDensityFunction wraps the model together with its
-the type of varinfo to be used, as well as the evaluation context. These must
-be known in order to calculate the log density (using
-[`DynamicPPL.evaluate!!`](@ref)).
+At its most basic level, a LogDensityFunction wraps the model together with
+the type of varinfo to be used, as well as the evaluation context and a function
+to extract the log density from the VarInfo. These must be known in order to
+calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
 
 If the `adtype` keyword argument is provided, then this struct will also store
 the adtype along with other information for efficient calculation of the
@@ -73,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
 1
 
 julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
-       f = LogDensityFunction(model, SimpleVarInfo(model));
+       f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));
 
 julia> LogDensityProblems.logdensity(f, [0.0])
 -2.3378770664093453
 
-julia> # LogDensityFunction respects the accumulators in VarInfo:
-       f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
+julia> # One can also specify evaluating e.g. the log prior only:
+       f_prior = LogDensityFunction(model, getlogprior);
 
 julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
 true
@@ -94,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
 ```
 """
 struct LogDensityFunction{
-    M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
+    M<:Model,F<:Function,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
 }
     "model used for evaluation"
    model::M
-    "varinfo used for evaluation"
+    "function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
+    getlogdensity::F
+    "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
    varinfo::V
    "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
    context::C
@@ -109,7 +112,8 @@ struct LogDensityFunction{
 
    function LogDensityFunction(
        model::Model,
-        varinfo::AbstractVarInfo=VarInfo(model),
+        getlogdensity::Function=getlogjoint,
+        varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
        context::AbstractContext=leafcontext(model.context);
        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
    )
@@ -125,21 +129,22 @@ struct LogDensityFunction{
            x = map(identity, varinfo[:])
            if use_closure(adtype)
                prep = DI.prepare_gradient(
-                    x -> logdensity_at(x, model, varinfo, context), adtype, x
+                    x -> logdensity_at(x, model, getlogdensity, varinfo, context), adtype, x
                )
            else
                prep = DI.prepare_gradient(
                    logdensity_at,
                    adtype,
                    x,
                    DI.Constant(model),
+                    DI.Constant(getlogdensity),
                    DI.Constant(varinfo),
                    DI.Constant(context),
                )
            end
        end
-        return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep
+        return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(context),typeof(adtype)}(
+            model, getlogdensity, varinfo, context, adtype, prep
        )
    end
 end
@@ -164,64 +169,80 @@ function LogDensityFunction(
    end
 end
 
+"""
+    ldf_default_varinfo(model::Model, getlogdensity::Function)
+
+Create the default AbstractVarInfo that should be used for evaluating the log density.
+
+Only the accumulators necessary for `getlogdensity` will be used.
+"""
+function ldf_default_varinfo(::Model, getlogdensity::Function)
+    msg = """
+        LogDensityFunction does not know what sort of VarInfo should be used when \
+        `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
+        """
+    error(msg)
+end
+
+ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)
+
+function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
+    return setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
+end
+
+function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))
+    return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),))
+end
+
 """
     logdensity_at(
         x::AbstractVector,
         model::Model,
+        getlogdensity::Function,
         varinfo::AbstractVarInfo,
         context::AbstractContext
     )
 
 Evaluate the log density of the given `model` at the given parameter values `x`,
 using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
 only for its structure, in the sense that the parameters from the vector `x` are inserted
-into it, and its own parameters are discarded. It does, however, determine whether the log
-prior, likelihood, or joint is returned, based on which accumulators are set in it.
+into it, and its own parameters are discarded. `getlogdensity` is the function that extracts
+the log density from the evaluated varinfo.
 """
 function logdensity_at(
-    x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
+    x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo, context::AbstractContext
 )
    varinfo_new = unflatten(varinfo, x)
    varinfo_eval = last(evaluate!!(model, varinfo_new, context))
-    has_prior = hasacc(varinfo_eval, Val(:LogPrior))
-    has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
-    if has_prior && has_likelihood
-        return getlogjoint(varinfo_eval)
-    elseif has_prior
-        return getlogprior(varinfo_eval)
-    elseif has_likelihood
-        return getloglikelihood(varinfo_eval)
-    else
-        error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
-    end
+    return getlogdensity(varinfo_eval)
 end
 
 ### LogDensityProblems interface
 
 function LogDensityProblems.capabilities(
-    ::Type{<:LogDensityFunction{M,V,C,Nothing}}
-) where {M,V,C}
+    ::Type{<:LogDensityFunction{M,F,V,C,Nothing}}
+) where {M,F,V,C}
    return LogDensityProblems.LogDensityOrder{0}()
 end
 function LogDensityProblems.capabilities(
-    ::Type{<:LogDensityFunction{M,V,C,AD}}
-) where {M,V,C,AD<:ADTypes.AbstractADType}
+    ::Type{<:LogDensityFunction{M,F,V,C,AD}}
+) where {M,F,V,C,AD<:ADTypes.AbstractADType}
    return LogDensityProblems.LogDensityOrder{1}()
 end
 function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
-    return logdensity_at(x, f.model, f.varinfo, f.context)
+    return logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context)
 end
 function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
-) where {M,V,C,AD<:ADTypes.AbstractADType}
+    f::LogDensityFunction{M,F,V,C,AD}, x::AbstractVector
+) where {M,F,V,C,AD<:ADTypes.AbstractADType}
    f.prep === nothing &&
        error("Gradient preparation not available; this should not happen")
    x = map(identity, x)  # Concretise type
    # Make branching statically inferrable, i.e. type-stable (even if the two
    # branches happen to return different types)
    return if use_closure(f.adtype)
        DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
+            x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context), f.prep, f.adtype, x
        )
    else
        DI.value_and_gradient(
@@ -230,6 +251,7 @@ function LogDensityProblems.logdensity_and_gradient(
            f.adtype,
            x,
            DI.Constant(f.model),
+            DI.Constant(f.getlogdensity),
            DI.Constant(f.varinfo),
            DI.Constant(f.context),
        )
@@ -304,7 +326,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
 Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
 """
 function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
-    return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
+    return LogDensityFunction(model, f.getlogdensity, f.varinfo, f.context; adtype=f.adtype)
 end
 