Skip to content

Commit 1f1ec85

Browse files
committed
Give LogDensityFunction the getlogdensity field
1 parent d4ef1f2 commit 1f1ec85

File tree

5 files changed

+74
-43
lines changed

5 files changed

+74
-43
lines changed

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8787
vi = DynamicPPL.link(vi, model)
8888
end
8989

90-
f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
90+
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi, context; adtype=adbackend)
9191
# The parameters at which we evaluate f.
9292
θ = vi[:]
9393

src/logdensityfunction.jl

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1717
"""
1818
LogDensityFunction(
1919
model::Model,
20-
varinfo::AbstractVarInfo=VarInfo(model),
20+
getlogdensity::Function=getlogjoint,
21+
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
2122
context::AbstractContext=DefaultContext();
2223
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
2324
)
@@ -28,10 +29,10 @@ A struct which contains a model, along with all the information necessary to:
2829
- and if `adtype` is provided, calculate the gradient of the log density at
2930
that point.
3031
31-
At its most basic level, a LogDensityFunction wraps the model together with its
32-
the type of varinfo to be used, as well as the evaluation context. These must
33-
be known in order to calculate the log density (using
34-
[`DynamicPPL.evaluate!!`](@ref)).
32+
At its most basic level, a LogDensityFunction wraps the model together with
33+
the type of varinfo to be used, as well as the evaluation context and a function
34+
to extract the log density from the VarInfo. These must be known in order to
35+
calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).
3536
3637
If the `adtype` keyword argument is provided, then this struct will also store
3738
the adtype along with other information for efficient calculation of the
@@ -73,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
7374
1
7475
7576
julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
76-
f = LogDensityFunction(model, SimpleVarInfo(model));
77+
f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));
7778
7879
julia> LogDensityProblems.logdensity(f, [0.0])
7980
-2.3378770664093453
8081
81-
julia> # LogDensityFunction respects the accumulators in VarInfo:
82-
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
82+
julia> # One can also specify evaluating e.g. the log prior only:
83+
f_prior = LogDensityFunction(model, getprior);
8384
8485
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
8586
true
@@ -94,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
9495
```
9596
"""
9697
struct LogDensityFunction{
97-
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
98+
M<:Model,F<:Function,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
9899
}
99100
"model used for evaluation"
100101
model::M
101-
"varinfo used for evaluation"
102+
"function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
103+
getlogdensity::F
104+
"varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
102105
varinfo::V
103106
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
104107
context::C
@@ -109,7 +112,8 @@ struct LogDensityFunction{
109112

110113
function LogDensityFunction(
111114
model::Model,
112-
varinfo::AbstractVarInfo=VarInfo(model),
115+
getlogdensity::Function=getlogjoint,
116+
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity),
113117
context::AbstractContext=leafcontext(model.context);
114118
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
115119
)
@@ -125,21 +129,22 @@ struct LogDensityFunction{
125129
x = map(identity, varinfo[:])
126130
if use_closure(adtype)
127131
prep = DI.prepare_gradient(
128-
x -> logdensity_at(x, model, varinfo, context), adtype, x
132+
x -> logdensity_at(x, model, getlogdensity, varinfo, context), adtype, x
129133
)
130134
else
131135
prep = DI.prepare_gradient(
132136
logdensity_at,
133137
adtype,
134138
x,
135139
DI.Constant(model),
140+
DI.Constant(getlogdensity),
136141
DI.Constant(varinfo),
137142
DI.Constant(context),
138143
)
139144
end
140145
end
141-
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
142-
model, varinfo, context, adtype, prep
146+
return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(context),typeof(adtype)}(
147+
model, getlogdensity, varinfo, context, adtype, prep
143148
)
144149
end
145150
end
@@ -164,64 +169,80 @@ function LogDensityFunction(
164169
end
165170
end
166171

172+
"""
173+
ldf_default_varinfo(model::Model, getlogdensity::Function)
174+
175+
Create the default AbstractVarInfo that should be used for evaluating the log density.
176+
177+
Only the accumulators necesessary for `getlogdensity` will be used.
178+
"""
179+
function ldf_default_varinfo(::Model, getlogdensity::Function)
180+
msg = """
181+
LogDensityFunction does not know what sort of VarInfo should be used when \
182+
`getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
183+
"""
184+
error(msg)
185+
end
186+
187+
ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)
188+
189+
function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
190+
return setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
191+
end
192+
193+
function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))
194+
return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),))
195+
end
196+
167197
"""
168198
logdensity_at(
169199
x::AbstractVector,
170200
model::Model,
201+
getlogdensity::Function,
171202
varinfo::AbstractVarInfo,
172203
context::AbstractContext
173204
)
174205
175206
Evaluate the log density of the given `model` at the given parameter values `x`,
176207
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
177208
only for its structure, in the sense that the parameters from the vector `x` are inserted
178-
into it, and its own parameters are discarded. It does, however, determine whether the log
179-
prior, likelihood, or joint is returned, based on which accumulators are set in it.
209+
into it, and its own parameters are discarded. `getlogdensity` is the function that extracts
210+
the log density from the evaluated varinfo.
180211
"""
181212
function logdensity_at(
182-
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
213+
x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo, context::AbstractContext
183214
)
184215
varinfo_new = unflatten(varinfo, x)
185216
varinfo_eval = last(evaluate!!(model, varinfo_new, context))
186-
has_prior = hasacc(varinfo_eval, Val(:LogPrior))
187-
has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
188-
if has_prior && has_likelihood
189-
return getlogjoint(varinfo_eval)
190-
elseif has_prior
191-
return getlogprior(varinfo_eval)
192-
elseif has_likelihood
193-
return getloglikelihood(varinfo_eval)
194-
else
195-
error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
196-
end
217+
return getlogdensity(varinfo_eval)
197218
end
198219

199220
### LogDensityProblems interface
200221

201222
function LogDensityProblems.capabilities(
202-
::Type{<:LogDensityFunction{M,V,C,Nothing}}
203-
) where {M,V,C}
223+
::Type{<:LogDensityFunction{M,F,V,C,Nothing}}
224+
) where {M,F,V,C}
204225
return LogDensityProblems.LogDensityOrder{0}()
205226
end
206227
function LogDensityProblems.capabilities(
207-
::Type{<:LogDensityFunction{M,V,C,AD}}
208-
) where {M,V,C,AD<:ADTypes.AbstractADType}
228+
::Type{<:LogDensityFunction{M,F,V,C,AD}}
229+
) where {M,F,V,C,AD<:ADTypes.AbstractADType}
209230
return LogDensityProblems.LogDensityOrder{1}()
210231
end
211232
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
212-
return logdensity_at(x, f.model, f.varinfo, f.context)
233+
return logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context)
213234
end
214235
function LogDensityProblems.logdensity_and_gradient(
215-
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
216-
) where {M,V,C,AD<:ADTypes.AbstractADType}
236+
f::LogDensityFunction{M,F,V,C,AD}, x::AbstractVector
237+
) where {M,F,V,C,AD<:ADTypes.AbstractADType}
217238
f.prep === nothing &&
218239
error("Gradient preparation not available; this should not happen")
219240
x = map(identity, x) # Concretise type
220241
# Make branching statically inferrable, i.e. type-stable (even if the two
221242
# branches happen to return different types)
222243
return if use_closure(f.adtype)
223244
DI.value_and_gradient(
224-
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
245+
x -> logdensity_at(x, f.model, f.getlogdensity, f.varinfo, f.context), f.prep, f.adtype, x
225246
)
226247
else
227248
DI.value_and_gradient(
@@ -230,6 +251,7 @@ function LogDensityProblems.logdensity_and_gradient(
230251
f.adtype,
231252
x,
232253
DI.Constant(f.model),
254+
DI.Constant(f.getlogdensity),
233255
DI.Constant(f.varinfo),
234256
DI.Constant(f.context),
235257
)
@@ -304,7 +326,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
304326
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
305327
"""
306328
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
307-
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
329+
return LogDensityFunction(model, f.getlogdensity, f.varinfo, f.context; adtype=f.adtype)
308330
end
309331

310332
"""

test/ad.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ using DynamicPPL: LogDensityFunction
2424

2525
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
2626
linked_varinfo = DynamicPPL.link(varinfo, m)
27-
f = LogDensityFunction(m, linked_varinfo)
27+
f = LogDensityFunction(m, getlogjoint, linked_varinfo)
2828
x = DynamicPPL.getparams(f)
2929
# Calculate reference logp + gradient of logp using ForwardDiff
30-
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
30+
ref_ldf = LogDensityFunction(m, getlogjoint, linked_varinfo; adtype=ref_adtype)
3131
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
3232

3333
@testset "$adtype" for adtype in test_adtypes
@@ -106,7 +106,7 @@ using DynamicPPL: LogDensityFunction
106106
spl = Sampler(MyEmptyAlg())
107107
vi = VarInfo(model)
108108
ldf = LogDensityFunction(
109-
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
109+
model, getlogjoint, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
110110
)
111111
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
112112
end

test/logdensityfunction.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,17 @@ end
1515
vns = DynamicPPL.TestUtils.varnames(model)
1616
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
1717

18+
vi = first(varinfos)
19+
theta = vi[:]
20+
ldf_joint = DynamicPPL.LogDensityFunction(model)
21+
@test LogDensityProblems.logdensity(ldf_joint, theta) logjoint(model, vi)
22+
ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior)
23+
@test LogDensityProblems.logdensity(ldf_prior, theta) logprior(model, vi)
24+
ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood)
25+
@test LogDensityProblems.logdensity(ldf_likelihood, theta) loglikelihood(model, vi)
26+
1827
@testset "$(varinfo)" for varinfo in varinfos
19-
logdensity = DynamicPPL.LogDensityFunction(model, varinfo)
28+
logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo)
2029
θ = varinfo[:]
2130
@test LogDensityProblems.logdensity(logdensity, θ) logjoint(model, varinfo)
2231
@test LogDensityProblems.dimension(logdensity) == length(θ)

test/test_util.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function test_model_ad(model, logp_manual)
1414
x = vi[:]
1515

1616
# Log probabilities using the model.
17-
= DynamicPPL.LogDensityFunction(model, vi)
17+
= DynamicPPL.LogDensityFunction(model, getlogjoint, vi)
1818
logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ)
1919

2020
# Check that both functions return the same values.

0 commit comments

Comments
 (0)