Skip to content

Commit 18beb57

Browse files
committed
record single prior components
by forwarding dot_tilde_assume to tilde_assume
1 parent 5842656 commit 18beb57

File tree

3 files changed

+35
-23
lines changed

3 files changed

+35
-23
lines changed

src/pointwise_logdensities.jl

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,37 +132,49 @@ end
132132
function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi)
133133
#@info "PointwiseLogdensityContext tilde_assume!! called for $vn"
134134
value, logp, vi = tilde_assume(context.context, right, vn, vi)
135-
#sym = DynamicPPL.getsym(vn)
136-
new_context = acc_logp!(context, vn, logp)
135+
push!(context, vn, logp)
137136
return value, logp, vi
138137
end
139138

140-
function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vn, vi)
141-
#@info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn"
142-
# @show vn, left, right, typeof(context).name
143-
value, logp, vi = dot_tilde_assume(context.context, right, left, vn, vi)
144-
new_context = acc_logp!(context, vn, logp)
139+
function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi)
140+
#@info "PointwiseLogdensityContext dot_tilde_assume called for $vns"
141+
value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi)
142+
# dispatch recording of log-densities based on type of right
143+
logps = record_dot_tilde_assume(context, right, left, vns, vi, logp)
144+
sum(logps) logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig")
145145
return value, logp, vi
146146
end
147147

148-
function acc_logp!(context::PointwiseLogdensityContext, vn::VarName, logp)
149-
push!(context, vn, logp)
150-
return (context)
148+
function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::UnivariateDistribution, left, vns, vi, logp)
149+
# forward to tilde_assume for each variable
150+
map(vns) do vn
151+
value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi)
152+
logp_i
153+
end
151154
end
152155

153-
function acc_logp!(context::PointwiseLogdensityContext, vns::AbstractVector{<:VarName}, logp)
154-
# construct a new VarName from given sequence of VarName
155-
# assume that all items in vns have an IndexLens optic
156-
indices = tuplejoin(map(vn -> getoptic(vn).indices, vns)...)
157-
vn = VarName(first(vns), Accessors.IndexLens(indices))
158-
push!(context, vn, logp)
159-
return (context)
156+
function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp)
157+
# forward to tilde_assume for each variable and distribution
158+
logps = map(vns, rights) do vn, right
159+
# use current context to record vn
160+
value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi)
161+
logp_i
162+
end
160163
end
161164

162-
#https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8
163-
@inline tuplejoin(x) = x
164-
@inline tuplejoin(x, y) = (x..., y...)
165-
@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...)
165+
function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::MultivariateDistribution, left, vns, vi, logp)
166+
#@info "PointwiseLogdensityContext record_dot_tilde_assume multivariate called for $vns"
167+
# For multivariate distribution on the right there is only a single density.
168+
# Need to construct a combined VarName.
169+
# Assume that all vns have an IndexLens with a Colon at the first position
170+
# and a single number at the second position.
171+
indices = map(vn -> getoptic(vn).indices[2], vns)
172+
indices_combined = (:,indices)
173+
#indices = tuplejoin(map(vn -> getoptic(vn).indices[2], vns)...)
174+
vn = VarName(first(vns), Accessors.IndexLens(indices_combined))
175+
push!(context, vn, logp)
176+
return logp
177+
end
166178

167179
() -> begin
168180
# code that generates julia-repl in docstring below

test/pointwise_logdensitiesjl renamed to test/pointwise_logdensities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
prior_context = PriorContext()
44
mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2)
55
mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
6-
#m = DynamicPPL.TestUtils.DEMO_MODELS[1]
6+
#m = DynamicPPL.TestUtils.DEMO_MODELS[12]
77
@testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS)
88
#@show i
99
example_values = DynamicPPL.TestUtils.rand_prior_true(m)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ include("test_util.jl")
5757

5858
include("serialization.jl")
5959

60-
include("pointwise_logdensitiesjl")
60+
include("pointwise_logdensities.jl")
6161

6262
include("lkj.jl")
6363

0 commit comments

Comments
 (0)