@@ -4,8 +4,9 @@ import DifferentiationInterface as DI
4
4
LogDensityFunction(
5
5
model::Model,
6
6
varinfo::AbstractVarInfo=VarInfo(model),
7
- context::AbstractContext=DefaultContext(),
8
- adtype::Union{ADTypes.AbstractADType,Nothing}=nothing)
7
+ context::AbstractContext=DefaultContext();
8
+ adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
9
+ )
9
10
10
11
A struct which contains a model, along with all the information necessary to:
11
12
@@ -16,8 +17,9 @@ At its most basic level, a LogDensityFunction wraps the model together with its
16
17
the type of varinfo to be used, as well as the evaluation context. These must
17
18
be known in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
18
19
19
- If `adtype` is provided, then this struct will also contain the adtype along with
20
- other information for efficient calculation of the gradient of the log density.
20
+ If the `adtype` keyword argument is provided, then this struct will also
21
+ store the adtype along with other information for efficient calculation of the
22
+ gradient of the log density.
21
23
22
24
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
23
25
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
@@ -27,6 +29,7 @@ concrete AD backend type, then `logdensity_and_gradient` is also implemented.
27
29
$(FIELDS)
28
30
29
31
# Examples
32
+
30
33
```jldoctest
31
34
julia> using Distributions
32
35
@@ -62,18 +65,26 @@ julia> # This also respects the context in `model`.
62
65
63
66
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
64
67
true
68
+
69
+ julia> # If we also need to calculate the gradient, we can specify an AD backend.
70
+ import ForwardDiff, ADTypes
71
+
72
+ julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
73
+
74
+ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
75
+ (-2.3378770664093453, [1.0])
65
76
```
66
77
"""
67
78
struct LogDensityFunction{
68
- V<: AbstractVarInfo ,
69
79
M<: Model ,
70
- C<: Union{Nothing,AbstractContext} ,
80
+ V<: AbstractVarInfo ,
81
+ C<: AbstractContext ,
71
82
AD<: Union{Nothing,ADTypes.AbstractADType} ,
72
83
}
73
- " varinfo used for evaluation"
74
- varinfo:: V
75
84
" model used for evaluation"
76
85
model:: M
86
+ " varinfo used for evaluation"
87
+ varinfo:: V
77
88
" context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
78
89
context:: C
79
90
" AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
@@ -86,7 +97,7 @@ struct LogDensityFunction{
86
97
function LogDensityFunction (
87
98
model:: Model ,
88
99
varinfo:: AbstractVarInfo = VarInfo (model),
89
- context:: AbstractContext = DefaultContext ( );
100
+ context:: AbstractContext = leafcontext (model . context );
90
101
adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
91
102
)
92
103
if adtype === nothing
@@ -112,8 +123,8 @@ struct LogDensityFunction{
112
123
end
113
124
with_closure = with_closure
114
125
end
115
- return new {typeof(varinfo ),typeof(model ),typeof(context),typeof(adtype)} (
116
- varinfo, model , context, adtype, prep, with_closure
126
+ return new {typeof(model ),typeof(varinfo ),typeof(context),typeof(adtype)} (
127
+ model, varinfo , context, adtype, prep, with_closure
117
128
)
118
129
end
119
130
end
@@ -161,21 +172,21 @@ end
161
172
# ## LogDensityProblems interface
162
173
163
174
function LogDensityProblems. capabilities (
164
- :: Type{<:LogDensityFunction{V,M ,C,Nothing}}
165
- ) where {V,M ,C}
175
+ :: Type{<:LogDensityFunction{M,V ,C,Nothing}}
176
+ ) where {M,V ,C}
166
177
return LogDensityProblems. LogDensityOrder {0} ()
167
178
end
168
179
function LogDensityProblems. capabilities (
169
- :: Type{<:LogDensityFunction{V,M ,C,AD}}
170
- ) where {V,M ,C,AD<: ADTypes.AbstractADType }
180
+ :: Type{<:LogDensityFunction{M,V ,C,AD}}
181
+ ) where {M,V ,C,AD<: ADTypes.AbstractADType }
171
182
return LogDensityProblems. LogDensityOrder {1} ()
172
183
end
173
184
function LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
174
185
return logdensity_at (x, f. model, f. varinfo, f. context)
175
186
end
176
187
function LogDensityProblems. logdensity_and_gradient (
177
- f:: LogDensityFunction{V,M ,C,AD} , x:: AbstractVector
178
- ) where {V,M ,C,AD<: ADTypes.AbstractADType }
188
+ f:: LogDensityFunction{M,V ,C,AD} , x:: AbstractVector
189
+ ) where {M,V ,C,AD<: ADTypes.AbstractADType }
179
190
f. prep === nothing &&
180
191
error (" Gradient preparation not available; this should not happen" )
181
192
x = map (identity, x) # Concretise type
@@ -231,11 +242,6 @@ use_closure(::ADTypes.AutoForwardDiff) = false
231
242
use_closure (:: ADTypes.AutoMooncake ) = false
232
243
use_closure (:: ADTypes.AutoReverseDiff ) = true
233
244
234
- # If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
235
- function getcontext (f:: LogDensityFunction )
236
- return f. context === nothing ? leafcontext (f. model. context) : f. context
237
- end
238
-
239
245
"""
240
246
getmodel(f)
241
247
@@ -249,7 +255,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
249
255
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
250
256
"""
251
257
function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
252
- return Accessors . @set f. model = model
258
+ return LogDensityFunction (model, f . varinfo, f. context; adtype = f . adtype)
253
259
end
254
260
255
261
"""
0 commit comments