Skip to content

Commit 8c98a73

Browse files
committed
Fix tests
1 parent aef4b91 commit 8c98a73

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

src/logdensityfunction.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import DifferentiationInterface as DI
44
LogDensityFunction(
55
model::Model,
66
varinfo::AbstractVarInfo=VarInfo(model),
7-
context::AbstractContext=DefaultContext(),
8-
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing)
7+
context::AbstractContext=DefaultContext();
8+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
9+
)
910
1011
A struct which contains a model, along with all the information necessary to:
1112
@@ -16,8 +17,9 @@ At its most basic level, a LogDensityFunction wraps the model together with its
1617
the type of varinfo to be used, as well as the evaluation context. These must
1718
be known in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
1819
19-
If `adtype` is provided, then this struct will also contain the adtype along with
20-
other information for efficient calculation of the gradient of the log density.
20+
If the `adtype` keyword argument is provided, then this struct will also
21+
store the adtype along with other information for efficient calculation of the
22+
gradient of the log density.
2123
2224
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
2325
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
@@ -27,6 +29,7 @@ concrete AD backend type, then `logdensity_and_gradient` is also implemented.
2729
$(FIELDS)
2830
2931
# Examples
32+
3033
```jldoctest
3134
julia> using Distributions
3235
@@ -62,18 +65,26 @@ julia> # This also respects the context in `model`.
6265
6366
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
6467
true
68+
69+
julia> # If we also need to calculate the gradient, we can specify an AD backend.
70+
import ForwardDiff, ADTypes
71+
72+
julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
73+
74+
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
75+
(-2.3378770664093453, [1.0])
6576
```
6677
"""
6778
struct LogDensityFunction{
68-
V<:AbstractVarInfo,
6979
M<:Model,
70-
C<:Union{Nothing,AbstractContext},
80+
V<:AbstractVarInfo,
81+
C<:AbstractContext,
7182
AD<:Union{Nothing,ADTypes.AbstractADType},
7283
}
73-
"varinfo used for evaluation"
74-
varinfo::V
7584
"model used for evaluation"
7685
model::M
86+
"varinfo used for evaluation"
87+
varinfo::V
7788
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
7889
context::C
7990
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
@@ -86,7 +97,7 @@ struct LogDensityFunction{
8697
function LogDensityFunction(
8798
model::Model,
8899
varinfo::AbstractVarInfo=VarInfo(model),
89-
context::AbstractContext=DefaultContext();
100+
context::AbstractContext=leafcontext(model.context);
90101
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
91102
)
92103
if adtype === nothing
@@ -112,8 +123,8 @@ struct LogDensityFunction{
112123
end
113124
with_closure = with_closure
114125
end
115-
return new{typeof(varinfo),typeof(model),typeof(context),typeof(adtype)}(
116-
varinfo, model, context, adtype, prep, with_closure
126+
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
127+
model, varinfo, context, adtype, prep, with_closure
117128
)
118129
end
119130
end
@@ -161,21 +172,21 @@ end
161172
### LogDensityProblems interface
162173

163174
function LogDensityProblems.capabilities(
164-
::Type{<:LogDensityFunction{V,M,C,Nothing}}
165-
) where {V,M,C}
175+
::Type{<:LogDensityFunction{M,V,C,Nothing}}
176+
) where {M,V,C}
166177
return LogDensityProblems.LogDensityOrder{0}()
167178
end
168179
function LogDensityProblems.capabilities(
169-
::Type{<:LogDensityFunction{V,M,C,AD}}
170-
) where {V,M,C,AD<:ADTypes.AbstractADType}
180+
::Type{<:LogDensityFunction{M,V,C,AD}}
181+
) where {M,V,C,AD<:ADTypes.AbstractADType}
171182
return LogDensityProblems.LogDensityOrder{1}()
172183
end
173184
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
174185
return logdensity_at(x, f.model, f.varinfo, f.context)
175186
end
176187
function LogDensityProblems.logdensity_and_gradient(
177-
f::LogDensityFunction{V,M,C,AD}, x::AbstractVector
178-
) where {V,M,C,AD<:ADTypes.AbstractADType}
188+
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
189+
) where {M,V,C,AD<:ADTypes.AbstractADType}
179190
f.prep === nothing &&
180191
error("Gradient preparation not available; this should not happen")
181192
x = map(identity, x) # Concretise type
@@ -231,11 +242,6 @@ use_closure(::ADTypes.AutoForwardDiff) = false
231242
use_closure(::ADTypes.AutoMooncake) = false
232243
use_closure(::ADTypes.AutoReverseDiff) = true
233244

234-
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
235-
function getcontext(f::LogDensityFunction)
236-
return f.context === nothing ? leafcontext(f.model.context) : f.context
237-
end
238-
239245
"""
240246
getmodel(f)
241247
@@ -249,7 +255,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
249255
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
250256
"""
251257
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
252-
return Accessors.@set f.model = model
258+
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
253259
end
254260

255261
"""

0 commit comments

Comments
 (0)