Skip to content

Commit d54bfdc

Browse files
committed
Replaced acc_logp! in favour of something similar to the
`_pointwise_tilde_observe` method
1 parent 5842656 commit d54bfdc

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

src/pointwise_logdensities.jl

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -129,45 +129,44 @@ function _pointwise_tilde_observe(
129129
end
130130
end
131131

132-
function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi)
133-
#@info "PointwiseLogdensityContext tilde_assume!! called for $vn"
132+
function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
134133
value, logp, vi = tilde_assume(context.context, right, vn, vi)
135-
#sym = DynamicPPL.getsym(vn)
136-
new_context = acc_logp!(context, vn, logp)
137-
return value, logp, vi
138-
end
134+
# Track loglikelihood value.
135+
push!(context, vn, logp)
139136

140-
function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi)
141-
#@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn"
142-
# @show vn, left, right, typeof(context).name
143-
value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi)
144-
new_context = acc_logp!(context, vn, logp)
145-
return value, logp, vi
137+
return value, acclogp!!(vi, logp)
146138
end
147139

148-
function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp)
149-
push!(context, vn, logp)
150-
return (context)
140+
function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi)
141+
value, logps = _pointwise_tilde_assume(context, right, left, vns, vi)
142+
# Track loglikelihood values.
143+
for (vn, logp) in zip(vns, logps)
144+
push!(context, vn, logp)
145+
end
146+
return value, acclogp!!(vi, sum(logps))
151147
end
152148

153-
function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp)
154-
# construct a new VarName from given sequence of VarName
155-
# assume that all items in vns have an IndexLens optic
156-
indices = tuplejoin(map(vn -> getoptic(vn).indices, vns)...)
157-
vn = VarName(first(vns), Accessors.IndexLens(indices))
158-
push!(context, vn, logp)
159-
return (context)
149+
function _pointwise_tilde_assume(context, right, left, vns, vi)
150+
# We need to drop the `vi` returned.
151+
values_and_logps = broadcast(right, left, vns) do r, l, vn
152+
val, logp, _ = tilde_assume(context, r, vn, vi)
153+
return val, logp
154+
end
155+
return map(first, values_and_logps), map(last, values_and_logps)
160156
end
161157

162-
#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8
163-
@inline tuplejoin(x) = x
164-
@inline tuplejoin(x, y) = (x..., y...)
165-
@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...)
166-
167-
() -> begin
168-
# code that generates julia-repl in docstring below
169-
# using DynamicPPL, Turing
170-
# TODO when Turing version that is compatible with DynamicPPL 0.29 becomes available
158+
function _pointwise_tilde_assume(
159+
context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi
160+
)
161+
# We need to drop the `vi` returned.
162+
values_and_logps = map(eachcol(left), vns) do l, vn
163+
val, logp, _ = tilde_assume(context, right, vn, vi)
164+
return val, logp
165+
end
166+
# HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent.
167+
# But this also means that we need to first flatten the entire `values` component before recombining.
168+
values = recombine(right, mapreduce(vec first, vcat, values_and_logps), length(vns))
169+
return values, map(last, values_and_logps)
171170
end
172171

173172
"""
@@ -268,8 +267,9 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
268267
```
269268
270269
"""
271-
function pointwise_logdensities(model::Model, chain,
272-
context::AbstractContext=DefaultContext(), keytype::Type{T}=String) where {T}
270+
function pointwise_logdensities(
271+
model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String
272+
) where {T}
273273
# Get the data by executing the model once
274274
vi = VarInfo(model)
275275
point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context)
@@ -292,12 +292,12 @@ function pointwise_logdensities(model::Model, chain,
292292
return logdensities
293293
end
294294

295-
function pointwise_logdensities(model::Model,
296-
varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext())
295+
function pointwise_logdensities(
296+
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()
297+
)
297298
point_context = PointwiseLogdensityContext(
298-
OrderedDict{VarName,Vector{Float64}}(), context)
299+
OrderedDict{VarName,Vector{Float64}}(), context
300+
)
299301
model(varinfo, point_context)
300302
return point_context.logdensities
301303
end
302-
303-

0 commit comments

Comments
 (0)