@@ -129,45 +129,44 @@ function _pointwise_tilde_observe(
129129 end
130130end
131131
132- function tilde_assume (context:: PointwiseLogdensityContext , right, vn, vi)
133- # @info "PointwiseLogdensityContext tilde_assume!! called for $vn"
132+ function tilde_assume!! (context:: PointwiseLogdensityContext , right, vn, vi)
134133 value, logp, vi = tilde_assume (context. context, right, vn, vi)
135- # sym = DynamicPPL.getsym(vn)
136- new_context = acc_logp! (context, vn, logp)
137- return value, logp, vi
138- end
134+ # Track loglikelihood value.
135+ push! (context, vn, logp)
139136
140- function dot_tilde_assume (context:: PointwiseLogdensityContext , right, left, vn, vi)
141- # @info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn"
142- # @show vn, left, right, typeof(context).name
143- value, logp, vi = dot_tilde_assume (context. context, right, left, vn, vi)
144- new_context = acc_logp! (context, vn, logp)
145- return value, logp, vi
137+ return value, acclogp!! (vi, logp)
146138end
147139
148- function acc_logp! (context:: PointwiseLogdensityContext , vn:: VarName , logp)
149- push! (context, vn, logp)
150- return (context)
140+ function dot_tilde_assume!! (context:: PointwiseLogdensityContext , right, left, vns, vi)
141+ value, logps = _pointwise_tilde_assume (context, right, left, vns, vi)
142+ # Track loglikelihood values.
143+ for (vn, logp) in zip (vns, logps)
144+ push! (context, vn, logp)
145+ end
146+ return value, acclogp!! (vi, sum (logps))
151147end
152148
153- function acc_logp! (context:: PointwiseLogdensityContext , vns:: AbstractVector{<:VarName} , logp )
154- # construct a new VarName from given sequence of VarName
155- # assume that all items in vns have an IndexLens optic
156- indices = tuplejoin ( map (vn -> getoptic (vn) . indices, vns) ... )
157- vn = VarName ( first (vns), Accessors . IndexLens (indices))
158- push! (context, vn, logp)
159- return (context )
149+ function _pointwise_tilde_assume (context, right, left, vns, vi )
150+ # We need to drop the `vi` returned.
151+ values_and_logps = broadcast (right, left, vns) do r, l, vn
152+ val, logp, _ = tilde_assume (context, r, vn, vi )
153+ return val, logp
154+ end
155+ return map (first, values_and_logps), map (last, values_and_logps )
160156end
161157
162- # https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8
163- @inline tuplejoin (x) = x
164- @inline tuplejoin (x, y) = (x... , y... )
165- @inline tuplejoin (x, y, z... ) = (x... , tuplejoin (y, z... )... )
166-
167- () -> begin
168- # code that generates julia-repl in docstring below
169- # using DynamicPPL, Turing
170- # TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available
158+ function _pointwise_tilde_assume (
159+ context, right:: MultivariateDistribution , left:: AbstractMatrix , vns, vi
160+ )
161+ # We need to drop the `vi` returned.
162+ values_and_logps = map (eachcol (left), vns) do l, vn
163+ val, logp, _ = tilde_assume (context, right, vn, vi)
164+ return val, logp
165+ end
166+ # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent.
167+ # But this also means that we need to first flatten the entire `values` component before recombining.
168+ values = recombine (right, mapreduce (vec ∘ first, vcat, values_and_logps), length (vns))
169+ return values, map (last, values_and_logps)
171170end
172171
173172"""
@@ -268,8 +267,9 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
268267```
269268
270269"""
271- function pointwise_logdensities (model:: Model , chain,
272- context:: AbstractContext = DefaultContext (), keytype:: Type{T} = String) where {T}
270+ function pointwise_logdensities (
271+ model:: Model , chain, context:: AbstractContext = DefaultContext (), keytype:: Type{T} = String
272+ ) where {T}
273273 # Get the data by executing the model once
274274 vi = VarInfo (model)
275275 point_context = PointwiseLogdensityContext (OrderedDict {T,Vector{Float64}} (), context)
@@ -292,12 +292,12 @@ function pointwise_logdensities(model::Model, chain,
292292 return logdensities
293293end
294294
295- function pointwise_logdensities (model:: Model ,
296- varinfo:: AbstractVarInfo , context:: AbstractContext = DefaultContext ())
295+ function pointwise_logdensities (
296+ model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext = DefaultContext ()
297+ )
297298 point_context = PointwiseLogdensityContext (
298- OrderedDict {VarName,Vector{Float64}} (), context)
299+ OrderedDict {VarName,Vector{Float64}} (), context
300+ )
299301 model (varinfo, point_context)
300302 return point_context. logdensities
301303end
302-
303-
0 commit comments