Skip to content

Commit 656a757

Browse files
committed
avoid recording prior components on leaf-prior-context
and avoid recording likelihoods when invoked with leaf-Likelihood context
1 parent d9945d7 commit 656a757

File tree

4 files changed

+45
-27
lines changed

4 files changed

+45
-27
lines changed

src/pointwise_logdensities.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,26 @@ function Base.push!(
7373
return context.logdensities[vn] = logp
7474
end
7575

76+
77+
function _include_prior(context::PointwiseLogdensityContext)
78+
return leafcontext(context) isa Union{PriorContext,DefaultContext}
79+
end
80+
function _include_likelihood(context::PointwiseLogdensityContext)
81+
return leafcontext(context) isa Union{LikelihoodContext,DefaultContext}
82+
end
83+
84+
85+
7686
function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
7787
# Defer literal `observe` to child-context.
7888
return tilde_observe!!(context.context, right, left, vi)
7989
end
8090
function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
91+
# Completely defer to child context if we are not tracking likelihoods.
92+
if !(_include_likelihood(context))
93+
return tilde_observe!!(context.context, right, left, vn, vi)
94+
end
95+
8196
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
8297
# we have to intercept the call to `tilde_observe!`.
8398
logp, vi = tilde_observe(context.context, right, left, vi)
@@ -93,6 +108,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v
93108
return dot_tilde_observe!!(context.context, right, left, vi)
94109
end
95110
function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
111+
# Completely defer to child context if we are not tracking likelihoods.
112+
if !(_include_likelihood(context))
113+
return dot_tilde_observe!!(context.context, right, left, vn, vi)
114+
end
115+
96116
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
97117
# we have to intercept the call to `dot_tilde_observe!`.
98118

@@ -132,16 +152,19 @@ end
132152
function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi)
133153
#@info "PointwiseLogdensityContext tilde_assume called for $vn"
134154
value, logp, vi = tilde_assume(context.context, right, vn, vi)
135-
push!(context, vn, logp)
155+
if _include_prior(context)
156+
push!(context, vn, logp)
157+
end
136158
return value, logp, vi
137159
end
138160

139161
function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi)
140162
#@info "PointwiseLogdensityContext dot_tilde_assume called for $vns"
141163
value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi)
142-
# dispatch recording of log-densities based on type of right
143-
logps = record_dot_tilde_assume(context, right, left, vns, vi, logp)
144-
sum(logps) logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig")
164+
if _include_prior(context)
165+
logps = record_dot_tilde_assume(context, right, left, vns, vi, logp)
166+
sum(logps) logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig")
167+
end
145168
return value, logp, vi
146169
end
147170

@@ -172,7 +195,7 @@ end
172195
pointwise_logdensities(model::Model, chain::Chains, keytype = String)
173196
174197
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
175-
with keys corresponding to symbols of the observations, and values being matrices
198+
with keys corresponding to symbols of the variables, and values being matrices
176199
of shape `(num_chains, num_samples)`.
177200
178201
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.

src/test_utils.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,10 +1116,8 @@ function TestLogModifyingChildContext(
11161116
mod, context
11171117
)
11181118
end
1119-
# Samplers call leafcontext(model.context) when evaluating log-densities
1120-
# Hence, in order to be used need to say that its a leaf-context
1121-
#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
1122-
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf()
1119+
1120+
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
11231121
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
11241122
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
11251123
return TestLogModifyingChildContext(context.mod, child)

test/deprecated.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010

1111
# Compute the pointwise loglikelihoods.
1212
lls = pointwise_loglikelihoods(m, vi)
13-
loglikelihood = sum(sum, values(lls))
1413

1514
#if isempty(lls)
16-
if loglikelihood 0.0 #isempty(lls)
15+
if isempty(lls)
1716
# One of the models with literal observations, so we just skip.
18-
# TODO: Think of better way to detect this special case
1917
continue
2018
end
2119

20+
loglikelihood = sum(sum, values(lls))
2221
loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...)
2322

2423
#priors =

test/pointwise_logdensities.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
prior_context = PriorContext()
44
mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2)
55
mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
6-
#m = DynamicPPL.TestUtils.DEMO_MODELS[12]
6+
#m = DynamicPPL.TestUtils.DEMO_MODELS[1]
77
#m = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2()
8-
@testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS)
8+
demo_models = (
9+
DynamicPPL.TestUtils.DEMO_MODELS...,
10+
DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2())
11+
@testset "$(m.f)" for (i, m) in enumerate(demo_models)
912
#@show i
1013
example_values = DynamicPPL.TestUtils.rand_prior_true(m)
1114

@@ -26,23 +29,19 @@
2629
# Compute the pointwise loglikelihoods.
2730
lls = pointwise_logdensities(m, vi, likelihood_context)
2831
#lls2 = pointwise_loglikelihoods(m, vi)
29-
loglikelihood_sum = sum(sum, values(lls))
30-
if loglikelihood_sum 0.0 #isempty(lls)
32+
if isempty(lls)
3133
# One of the models with literal observations, so we just skip.
32-
# TODO: Think of better way to detect this special case
3334
loglikelihood_true = 0.0
35+
else
36+
loglikelihood_sum = sum(sum, values(lls))
37+
@test loglikelihood_sum loglikelihood_true
3438
end
35-
@test loglikelihood_sum loglikelihood_true
3639

3740
# Compute the pointwise logdensities of the priors.
3841
lps_prior = pointwise_logdensities(m, vi, prior_context)
3942
logp = sum(sum, values(lps_prior))
40-
if false # isempty(lps_prior)
41-
# One of the models with only observations so we just skip.
42-
else
43-
logp1 = getlogp(vi)
44-
@test !isfinite(logp_true) || logp logp_true
45-
end
43+
logp1 = getlogp(vi)
44+
@test !isfinite(logp_true) || logp logp_true
4645

4746
# Compute both likelihood and logdensity of prior
4847
# using the default DefaultContex
@@ -57,7 +56,6 @@
5756
end
5857
end
5958

60-
6159
@testset "pointwise_logdensities chain" begin
6260
@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
6361
s ~ InverseGamma(2, 3)
@@ -73,9 +71,9 @@ end
7371
# generate the sample used below
7472
chain = sample(model, MH(), MCMCThreads(), 10, 2)
7573
arr0 = stack(Array(chain, append_chains=false))
76-
@show(arr0);
74+
@show(arr0[1:2,:,:]);
7775
end
78-
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939]
76+
arr0[1:2, :, :] = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497]
7977
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
8078
tmp1 = pointwise_logdensities(model, chain)
8179
vi = VarInfo(model)

0 commit comments

Comments
 (0)