@@ -19,6 +19,13 @@ function setchildcontext(context::PointwiseLogdensityContext, child)
1919 return PointwiseLogdensityContext (context. logdensities, child)
2020end
2121
22+ function _include_prior (context:: PointwiseLogdensityContext )
23+ return leafcontext (context) isa Union{PriorContext,DefaultContext}
24+ end
25+ function _include_likelihood (context:: PointwiseLogdensityContext )
26+ return leafcontext (context) isa Union{LikelihoodContext,DefaultContext}
27+ end
28+
2229function Base. push! (
2330 context:: PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}} ,
2431 vn:: VarName ,
@@ -78,6 +85,11 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
7885 return tilde_observe!! (context. context, right, left, vi)
7986end
8087function tilde_observe!! (context:: PointwiseLogdensityContext , right, left, vn, vi)
88+ # Completely defer to child context if we are not tracking likelihoods.
89+ if ! (_include_likelihood (context))
90+ return tilde_observe!! (context. context, right, left, vn, vi)
91+ end
92+
8193 # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
8294 # we have to intercept the call to `tilde_observe!`.
8395 logp, vi = tilde_observe (context. context, right, left, vi)
@@ -93,6 +105,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v
93105 return dot_tilde_observe!! (context. context, right, left, vi)
94106end
95107function dot_tilde_observe!! (context:: PointwiseLogdensityContext , right, left, vn, vi)
108+ # Completely defer to child context if we are not tracking likelihoods.
109+ if ! (_include_likelihood (context))
110+ return dot_tilde_observe!! (context. context, right, left, vn, vi)
111+ end
112+
96113 # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
97114 # we have to intercept the call to `dot_tilde_observe!`.
98115
@@ -130,6 +147,10 @@ function _pointwise_tilde_observe(
130147end
131148
132149function tilde_assume!! (context:: PointwiseLogdensityContext , right, vn, vi)
150+ # Completely defer to child context if we are not tracking prior densities.
151+ _include_prior (context) || return tilde_assume!! (context. context, right, vn, vi)
152+
153+ # Otherwise, capture the return values.
133154 value, logp, vi = tilde_assume (context. context, right, vn, vi)
134155 # Track loglikelihood value.
135156 push! (context, vn, logp)
@@ -138,6 +159,11 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
138159end
139160
140161function dot_tilde_assume!! (context:: PointwiseLogdensityContext , right, left, vns, vi)
162+ # Completely defer to child context if we are not tracking prior densities.
163+ if ! (_include_prior (context))
164+ return dot_tilde_assume!! (context. context, right, left, vns, vi)
165+ end
166+
141167 value, logps = _pointwise_tilde_assume (context, right, left, vns, vi)
142168 # Track loglikelihood values.
143169 for (vn, logp) in zip (vns, logps)
173199 pointwise_logdensities(model::Model, chain::Chains, keytype = String)
174200
175201Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
176- with keys corresponding to symbols of the observations , and values being matrices
202+ with keys corresponding to symbols of the variables , and values being matrices
177203of shape `(num_chains, num_samples)`.
178204
179205`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
@@ -268,7 +294,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
268294
269295"""
270296function pointwise_logdensities (
271- model:: Model , chain, context:: AbstractContext = DefaultContext (), keytype :: Type{T} = String
297+ model:: Model , chain, context:: AbstractContext = DefaultContext ()
272298) where {T}
273299 # Get the data by executing the model once
274300 vi = VarInfo (model)
@@ -301,3 +327,63 @@ function pointwise_logdensities(
301327 model (varinfo, point_context)
302328 return point_context. logdensities
303329end
330+
331+ """
332+ pointwise_loglikelihoods(model, chain[, context])
333+
334+ Compute the pointwise log-likelihoods of the model given the chain.
335+
336+ This is the same as `pointwise_logdensities(model, chain, context)`, but only
337+ including the likelihood terms.
338+
339+ See also: [`pointwise_logdensities`](@ref).
340+ """
341+ function pointwise_loglikelihoods (
342+ model:: Model , chain, context:: AbstractContext = LikelihoodContext ()
343+ ) where {T}
344+ if ! (leafcontext (context) isa LikelihoodContext)
345+ throw (ArgumentError (" Leaf context should be a LikelihoodContext" ))
346+ end
347+
348+ return pointwise_logdensities (model, chain, context)
349+ end
350+
351+ function pointwise_loglikelihoods (
352+ model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext = LikelihoodContext ()
353+ ) where {T}
354+ if ! (leafcontext (context) isa LikelihoodContext)
355+ throw (ArgumentError (" Leaf context should be a LikelihoodContext" ))
356+ end
357+
358+ return pointwise_logdensities (model, chain, context)
359+ end
360+
361+ """
362+ pointwise_prior_logdensities(model, chain[, context])
363+
364+ Compute the pointwise log-prior-densities of the model given the chain.
365+
366+ This is the same as `pointwise_logdensities(model, chain, context)`, but only
367+ including the prior terms.
368+
369+ See also: [`pointwise_logdensities`](@ref).
370+ """
371+ function pointwise_prior_logdensities (
372+ model:: Model , chain, context:: AbstractContext = PriorContext ()
373+ ) where {T}
374+ if ! (leafcontext (context) isa PriorContext)
375+ throw (ArgumentError (" Leaf context should be a PriorContext" ))
376+ end
377+
378+ return pointwise_logdensities (model, chain, context)
379+ end
380+
381+ function pointwise_prior_logdensities (
382+ model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext = PriorContext ()
383+ ) where {T}
384+ if ! (leafcontext (context) isa PriorContext)
385+ throw (ArgumentError (" Leaf context should be a PriorContext" ))
386+ end
387+
388+ return pointwise_logdensities (model, chain, context)
389+ end
0 commit comments