Skip to content

Commit 49ad8b0

Browse files
committed
Added back pointwise_loglikelihoods and a new function
`pointwise_prior_logdensities` + a mechanism to determine what we should include in the resulting dictionary based on the leaf context
1 parent 3d3a97e commit 49ad8b0

File tree

1 file changed

+88
-2
lines changed

1 file changed

+88
-2
lines changed

src/pointwise_logdensities.jl

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ function setchildcontext(context::PointwiseLogdensityContext, child)
1919
return PointwiseLogdensityContext(context.logdensities, child)
2020
end
2121

22+
function _include_prior(context::PointwiseLogdensityContext)
23+
return leafcontext(context) isa Union{PriorContext,DefaultContext}
24+
end
25+
function _include_likelihood(context::PointwiseLogdensityContext)
26+
return leafcontext(context) isa Union{LikelihoodContext,DefaultContext}
27+
end
28+
2229
function Base.push!(
2330
context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}},
2431
vn::VarName,
@@ -78,6 +85,11 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
7885
return tilde_observe!!(context.context, right, left, vi)
7986
end
8087
function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
88+
# Completely defer to child context if we are not tracking likelihoods.
89+
if !(_include_likelihood(context))
90+
return tilde_observe!!(context.context, right, left, vn, vi)
91+
end
92+
8193
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
8294
# we have to intercept the call to `tilde_observe!`.
8395
logp, vi = tilde_observe(context.context, right, left, vi)
@@ -93,6 +105,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v
93105
return dot_tilde_observe!!(context.context, right, left, vi)
94106
end
95107
function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
108+
# Completely defer to child context if we are not tracking likelihoods.
109+
if !(_include_likelihood(context))
110+
return dot_tilde_observe!!(context.context, right, left, vn, vi)
111+
end
112+
96113
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
97114
# we have to intercept the call to `dot_tilde_observe!`.
98115

@@ -130,6 +147,10 @@ function _pointwise_tilde_observe(
130147
end
131148

132149
function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
150+
# Completely defer to child context if we are not tracking prior densities.
151+
_include_prior(context) || return tilde_assume!!(context.context, right, vn, vi)
152+
153+
# Otherwise, capture the return values.
133154
value, logp, vi = tilde_assume(context.context, right, vn, vi)
134155
# Track loglikelihood value.
135156
push!(context, vn, logp)
@@ -138,6 +159,11 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
138159
end
139160

140161
function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi)
162+
# Completely defer to child context if we are not tracking prior densities.
163+
if !(_include_prior(context))
164+
return dot_tilde_assume!!(context.context, right, left, vns, vi)
165+
end
166+
141167
value, logps = _pointwise_tilde_assume(context, right, left, vns, vi)
142168
# Track loglikelihood values.
143169
for (vn, logp) in zip(vns, logps)
@@ -173,7 +199,7 @@ end
173199
pointwise_logdensities(model::Model, chain::Chains, keytype = String)
174200
175201
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
176-
with keys corresponding to symbols of the observations, and values being matrices
202+
with keys corresponding to symbols of the variables, and values being matrices
177203
of shape `(num_chains, num_samples)`.
178204
179205
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
@@ -268,7 +294,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
268294
269295
"""
270296
function pointwise_logdensities(
271-
model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String
297+
model::Model, chain, context::AbstractContext=DefaultContext()
272298
) where {T}
273299
# Get the data by executing the model once
274300
vi = VarInfo(model)
@@ -301,3 +327,63 @@ function pointwise_logdensities(
301327
model(varinfo, point_context)
302328
return point_context.logdensities
303329
end
330+
331+
"""
332+
pointwise_loglikelihoods(model, chain[, context])
333+
334+
Compute the pointwise log-likelihoods of the model given the chain.
335+
336+
This is the same as `pointwise_logdensities(model, chain, context)`, but only
337+
including the likelihood terms.
338+
339+
See also: [`pointwise_logdensities`](@ref).
340+
"""
341+
function pointwise_loglikelihoods(
342+
model::Model, chain, context::AbstractContext=LikelihoodContext()
343+
) where {T}
344+
if !(leafcontext(context) isa LikelihoodContext)
345+
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
346+
end
347+
348+
return pointwise_logdensities(model, chain, context)
349+
end
350+
351+
function pointwise_loglikelihoods(
352+
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext()
353+
) where {T}
354+
if !(leafcontext(context) isa LikelihoodContext)
355+
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
356+
end
357+
358+
return pointwise_logdensities(model, chain, context)
359+
end
360+
361+
"""
362+
pointwise_prior_logdensities(model, chain[, context])
363+
364+
Compute the pointwise log-prior-densities of the model given the chain.
365+
366+
This is the same as `pointwise_logdensities(model, chain, context)`, but only
367+
including the prior terms.
368+
369+
See also: [`pointwise_logdensities`](@ref).
370+
"""
371+
function pointwise_prior_logdensities(
372+
model::Model, chain, context::AbstractContext=PriorContext()
373+
) where {T}
374+
if !(leafcontext(context) isa PriorContext)
375+
throw(ArgumentError("Leaf context should be a PriorContext"))
376+
end
377+
378+
return pointwise_logdensities(model, chain, context)
379+
end
380+
381+
function pointwise_prior_logdensities(
382+
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext()
383+
) where {T}
384+
if !(leafcontext(context) isa PriorContext)
385+
throw(ArgumentError("Leaf context should be a PriorContext"))
386+
end
387+
388+
return pointwise_logdensities(model, chain, context)
389+
end

0 commit comments

Comments
 (0)