Skip to content

Commit 892b097

Browse files
committed
Fix tests
1 parent aef4b91 commit 892b097

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

src/logdensityfunction.jl

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import DifferentiationInterface as DI
44
LogDensityFunction(
55
model::Model,
66
varinfo::AbstractVarInfo=VarInfo(model),
7-
context::AbstractContext=DefaultContext(),
8-
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing)
7+
context::AbstractContext=DefaultContext();
8+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
9+
)
910
1011
A struct which contains a model, along with all the information necessary to:
1112
@@ -16,8 +17,9 @@ At its most basic level, a LogDensityFunction wraps the model together with its
1617
the type of varinfo to be used, as well as the evaluation context. These must
1718
be known in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
1819
19-
If `adtype` is provided, then this struct will also contain the adtype along with
20-
other information for efficient calculation of the gradient of the log density.
20+
If the `adtype` keyword argument is provided, then this struct will also
21+
store the adtype along with other information for efficient calculation of the
22+
gradient of the log density.
2123
2224
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
2325
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
@@ -27,6 +29,7 @@ concrete AD backend type, then `logdensity_and_gradient` is also implemented.
2729
$(FIELDS)
2830
2931
# Examples
32+
3033
```jldoctest
3134
julia> using Distributions
3235
@@ -62,18 +65,23 @@ julia> # This also respects the context in `model`.
6265
6366
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
6467
true
68+
69+
julia> # If we also need to calculate the gradient, we can specify an AD backend.
70+
import ForwardDiff, ADTypes
71+
72+
julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
73+
74+
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
75+
(-2.3378770664093453, [1.0])
6576
```
6677
"""
6778
struct LogDensityFunction{
68-
V<:AbstractVarInfo,
69-
M<:Model,
70-
C<:Union{Nothing,AbstractContext},
71-
AD<:Union{Nothing,ADTypes.AbstractADType},
79+
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
7280
}
73-
"varinfo used for evaluation"
74-
varinfo::V
7581
"model used for evaluation"
7682
model::M
83+
"varinfo used for evaluation"
84+
varinfo::V
7785
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
7886
context::C
7987
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
@@ -86,7 +94,7 @@ struct LogDensityFunction{
8694
function LogDensityFunction(
8795
model::Model,
8896
varinfo::AbstractVarInfo=VarInfo(model),
89-
context::AbstractContext=DefaultContext();
97+
context::AbstractContext=leafcontext(model.context);
9098
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
9199
)
92100
if adtype === nothing
@@ -112,8 +120,8 @@ struct LogDensityFunction{
112120
end
113121
with_closure = with_closure
114122
end
115-
return new{typeof(varinfo),typeof(model),typeof(context),typeof(adtype)}(
116-
varinfo, model, context, adtype, prep, with_closure
123+
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
124+
model, varinfo, context, adtype, prep, with_closure
117125
)
118126
end
119127
end
@@ -161,21 +169,21 @@ end
161169
### LogDensityProblems interface
162170

163171
function LogDensityProblems.capabilities(
164-
::Type{<:LogDensityFunction{V,M,C,Nothing}}
165-
) where {V,M,C}
172+
::Type{<:LogDensityFunction{M,V,C,Nothing}}
173+
) where {M,V,C}
166174
return LogDensityProblems.LogDensityOrder{0}()
167175
end
168176
function LogDensityProblems.capabilities(
169-
::Type{<:LogDensityFunction{V,M,C,AD}}
170-
) where {V,M,C,AD<:ADTypes.AbstractADType}
177+
::Type{<:LogDensityFunction{M,V,C,AD}}
178+
) where {M,V,C,AD<:ADTypes.AbstractADType}
171179
return LogDensityProblems.LogDensityOrder{1}()
172180
end
173181
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
174182
return logdensity_at(x, f.model, f.varinfo, f.context)
175183
end
176184
function LogDensityProblems.logdensity_and_gradient(
177-
f::LogDensityFunction{V,M,C,AD}, x::AbstractVector
178-
) where {V,M,C,AD<:ADTypes.AbstractADType}
185+
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
186+
) where {M,V,C,AD<:ADTypes.AbstractADType}
179187
f.prep === nothing &&
180188
error("Gradient preparation not available; this should not happen")
181189
x = map(identity, x) # Concretise type
@@ -231,11 +239,6 @@ use_closure(::ADTypes.AutoForwardDiff) = false
231239
use_closure(::ADTypes.AutoMooncake) = false
232240
use_closure(::ADTypes.AutoReverseDiff) = true
233241

234-
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
235-
function getcontext(f::LogDensityFunction)
236-
return f.context === nothing ? leafcontext(f.model.context) : f.context
237-
end
238-
239242
"""
240243
getmodel(f)
241244
@@ -249,7 +252,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
249252
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
250253
"""
251254
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
252-
return Accessors.@set f.model = model
255+
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
253256
end
254257

255258
"""

0 commit comments

Comments
 (0)