@@ -4,8 +4,9 @@ import DifferentiationInterface as DI
4
4
LogDensityFunction(
5
5
model::Model,
6
6
varinfo::AbstractVarInfo=VarInfo(model),
7
- context::AbstractContext=DefaultContext(),
8
- adtype::Union{ADTypes.AbstractADType,Nothing}=nothing)
7
+ context::AbstractContext=DefaultContext();
8
+ adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
9
+ )
9
10
10
11
A struct which contains a model, along with all the information necessary to:
11
12
@@ -16,8 +17,9 @@ At its most basic level, a LogDensityFunction wraps the model together with its
16
17
the type of varinfo to be used, as well as the evaluation context. These must
17
18
be known in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
18
19
19
- If `adtype` is provided, then this struct will also contain the adtype along with
20
- other information for efficient calculation of the gradient of the log density.
20
+ If the `adtype` keyword argument is provided, then this struct will also
21
+ store the adtype along with other information for efficient calculation of the
22
+ gradient of the log density.
21
23
22
24
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
23
25
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
@@ -27,6 +29,7 @@ concrete AD backend type, then `logdensity_and_gradient` is also implemented.
27
29
$(FIELDS)
28
30
29
31
# Examples
32
+
30
33
```jldoctest
31
34
julia> using Distributions
32
35
@@ -62,18 +65,23 @@ julia> # This also respects the context in `model`.
62
65
63
66
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
64
67
true
68
+
69
+ julia> # If we also need to calculate the gradient, we can specify an AD backend.
70
+ import ForwardDiff, ADTypes
71
+
72
+ julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
73
+
74
+ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
75
+ (-2.3378770664093453, [1.0])
65
76
```
66
77
"""
67
78
struct LogDensityFunction{
68
- V<: AbstractVarInfo ,
69
- M<: Model ,
70
- C<: Union{Nothing,AbstractContext} ,
71
- AD<: Union{Nothing,ADTypes.AbstractADType} ,
79
+ M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext ,AD<: Union{Nothing,ADTypes.AbstractADType}
72
80
}
73
- " varinfo used for evaluation"
74
- varinfo:: V
75
81
" model used for evaluation"
76
82
model:: M
83
+ " varinfo used for evaluation"
84
+ varinfo:: V
77
85
" context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
78
86
context:: C
79
87
" AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
@@ -86,7 +94,7 @@ struct LogDensityFunction{
86
94
function LogDensityFunction (
87
95
model:: Model ,
88
96
varinfo:: AbstractVarInfo = VarInfo (model),
89
- context:: AbstractContext = DefaultContext ( );
97
+ context:: AbstractContext = leafcontext (model . context );
90
98
adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
91
99
)
92
100
if adtype === nothing
@@ -112,8 +120,8 @@ struct LogDensityFunction{
112
120
end
113
121
with_closure = with_closure
114
122
end
115
- return new {typeof(varinfo ),typeof(model ),typeof(context),typeof(adtype)} (
116
- varinfo, model , context, adtype, prep, with_closure
123
+ return new {typeof(model ),typeof(varinfo ),typeof(context),typeof(adtype)} (
124
+ model, varinfo , context, adtype, prep, with_closure
117
125
)
118
126
end
119
127
end
@@ -161,21 +169,21 @@ end
161
169
# ## LogDensityProblems interface
162
170
163
171
function LogDensityProblems. capabilities (
164
- :: Type{<:LogDensityFunction{V,M ,C,Nothing}}
165
- ) where {V,M ,C}
172
+ :: Type{<:LogDensityFunction{M,V ,C,Nothing}}
173
+ ) where {M,V ,C}
166
174
return LogDensityProblems. LogDensityOrder {0} ()
167
175
end
168
176
function LogDensityProblems. capabilities (
169
- :: Type{<:LogDensityFunction{V,M ,C,AD}}
170
- ) where {V,M ,C,AD<: ADTypes.AbstractADType }
177
+ :: Type{<:LogDensityFunction{M,V ,C,AD}}
178
+ ) where {M,V ,C,AD<: ADTypes.AbstractADType }
171
179
return LogDensityProblems. LogDensityOrder {1} ()
172
180
end
173
181
function LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
174
182
return logdensity_at (x, f. model, f. varinfo, f. context)
175
183
end
176
184
function LogDensityProblems. logdensity_and_gradient (
177
- f:: LogDensityFunction{V,M ,C,AD} , x:: AbstractVector
178
- ) where {V,M ,C,AD<: ADTypes.AbstractADType }
185
+ f:: LogDensityFunction{M,V ,C,AD} , x:: AbstractVector
186
+ ) where {M,V ,C,AD<: ADTypes.AbstractADType }
179
187
f. prep === nothing &&
180
188
error (" Gradient preparation not available; this should not happen" )
181
189
x = map (identity, x) # Concretise type
@@ -231,11 +239,6 @@ use_closure(::ADTypes.AutoForwardDiff) = false
231
239
use_closure (:: ADTypes.AutoMooncake ) = false
232
240
use_closure (:: ADTypes.AutoReverseDiff ) = true
233
241
234
- # If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
235
- function getcontext (f:: LogDensityFunction )
236
- return f. context === nothing ? leafcontext (f. model. context) : f. context
237
- end
238
-
239
242
"""
240
243
getmodel(f)
241
244
@@ -249,7 +252,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
249
252
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
250
253
"""
251
254
function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
252
- return Accessors . @set f. model = model
255
+ return LogDensityFunction (model, f . varinfo, f. context; adtype = f . adtype)
253
256
end
254
257
255
258
"""
0 commit comments