Skip to content

Commit d9945d7

Browse files
committed
forward dot_tilde_assume to tilde_assume for Multivariate
1 parent 18beb57 commit d9945d7

File tree

3 files changed

+66
-22
lines changed

3 files changed

+66
-22
lines changed

src/pointwise_logdensities.jl

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v
9898

9999
# We want to treat `.~` as a collection of independent observations,
100100
# hence we need the `logp` for each of them. Broadcasting the univariate
101-
# `tilde_obseve` does exactly this.
101+
# `tilde_observe` does exactly this.
102102
logps = _pointwise_tilde_observe(context.context, right, left, vi)
103103

104104
# Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`.
@@ -129,8 +129,8 @@ function _pointwise_tilde_observe(
129129
end
130130
end
131131

132-
function tilde_assume(context::PointwiseLogdensityContext, right, vn, vi)
133-
#@info "PointwiseLogdensityContext tilde_assume!! called for $vn"
132+
function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi)
133+
#@info "PointwiseLogdensityContext tilde_assume called for $vn"
134134
value, logp, vi = tilde_assume(context.context, right, vn, vi)
135135
push!(context, vn, logp)
136136
return value, logp, vi
@@ -145,7 +145,7 @@ function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns,
145145
return value, logp, vi
146146
end
147147

148-
function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::UnivariateDistribution, left, vns, vi, logp)
148+
function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::Distribution, left, vns, vi, logp)
149149
# forward to tilde_assume for each variable
150150
map(vns) do vn
151151
value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi)
@@ -155,27 +155,13 @@ end
155155

156156
function record_dot_tilde_assume(context::PointwiseLogdensityContext, rights::AbstractVector{<:Distribution}, left, vns, vi, logp)
157157
# forward to tilde_assume for each variable and distribution
158-
logps = map(vns, rights) do vn, right
158+
map(vns, rights) do vn, right
159159
# use current context to record vn
160160
value_i, logp_i, vi_i = tilde_assume(context, right, vn, vi)
161161
logp_i
162162
end
163163
end
164164

165-
function record_dot_tilde_assume(context::PointwiseLogdensityContext, right::MultivariateDistribution, left, vns, vi, logp)
166-
#@info "PointwiseLogdensityContext record_dot_tilde_assume multivariate called for $vns"
167-
# For multivariate distribution on the right there is only a single density.
168-
# Need to construct a combined VarName.
169-
# Assume that all vns have an IndexLens with a Colon at the first position
170-
# and a single number at the second position.
171-
indices = map(vn -> getoptic(vn).indices[2], vns)
172-
indices_combined = (:,indices)
173-
#indices = tuplejoin(map(vn -> getoptic(vn).indices[2], vns)...)
174-
vn = VarName(first(vns), Accessors.IndexLens(indices_combined))
175-
push!(context, vn, logp)
176-
return logp
177-
end
178-
179165
() -> begin
180166
# code that generates julia-repl in docstring below
181167
# using DynamicPPL, Turing

src/test_utils.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,62 @@ function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)
667667
return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)]
668668
end
669669

670+
# model with truly two 3 columns in of d=2 (rather than ProductMatrix of single factor) to explore dot_tilde_assume with MultivariateDistribution
671+
@model function demo_dot_assume_matrix_dot_observe_matrix2(
672+
x=transpose([1.5 2.0;1.6 2.1;1.45 2.05]), ::Type{TV}=Array{Float64}
673+
) where {TV}
674+
d = size(x,1)
675+
n = size(x,2)
676+
s = TV(undef, d, n)
677+
# for i in 1:n
678+
# s[:,i] ~ product_distribution([InverseGamma(2, 3) for _ in 1:d])
679+
# end
680+
s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d])
681+
m = TV(undef, d, n)
682+
Sigma_x = Diagonal(s[:,1])
683+
for i in 1:n
684+
diag_s = Diagonal(s[:,i])
685+
m[:,i] ~ MvNormal(zeros(d), diag_s)
686+
x[:,i] ~ MvNormal(m[:,i], Sigma_x)
687+
end
688+
689+
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
690+
end
691+
function logprior_true(
692+
model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m
693+
)
694+
n = size(model.args.x,1)
695+
d = size(model.args.x,2)
696+
logd = map(1:d) do i_d
697+
s_vec = Diagonal(s[:,i_d])
698+
loglikelihood(InverseGamma(2, 3), s_vec) +
699+
logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m[:,i_d])
700+
end
701+
return sum(logd)
702+
end
703+
function loglikelihood_true(
704+
model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m
705+
)
706+
n = size(model.args.x,2)
707+
d = size(model.args.x,1)
708+
s_vec = Diagonal(s[:,1])
709+
logd = map(1:n) do i
710+
loglikelihood(MvNormal(m[:,i], Diagonal(s_vec)), model.args.x[:,i])
711+
end
712+
return sum(logd)
713+
end
714+
function logprior_true_with_logabsdet_jacobian(
715+
model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)}, s, m
716+
)
717+
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
718+
end
719+
function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)})
720+
s = m = zeros(2, 3) # used for varname concretization only
721+
return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(s[:, 3], true),
722+
@varname(m[:,1], true), @varname(m[:,2], true), @varname(m[:,3], true)]
723+
end
724+
725+
670726
@model function demo_assume_matrix_dot_observe_matrix(
671727
x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
672728
) where {TV}
@@ -748,6 +804,7 @@ const MultivariateAssumeDemoModels = Union{
748804
Model{typeof(demo_dot_assume_observe_submodel)},
749805
Model{typeof(demo_dot_assume_dot_observe_matrix)},
750806
Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)},
807+
Model{typeof(demo_dot_assume_matrix_dot_observe_matrix2)},
751808
}
752809
function posterior_mean(model::MultivariateAssumeDemoModels)
753810
# Get some containers to fill.

test/pointwise_logdensities.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2)
55
mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
66
#m = DynamicPPL.TestUtils.DEMO_MODELS[12]
7+
#m = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2()
78
@testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS)
89
#@show i
910
example_values = DynamicPPL.TestUtils.rand_prior_true(m)
@@ -25,13 +26,13 @@
2526
# Compute the pointwise loglikelihoods.
2627
lls = pointwise_logdensities(m, vi, likelihood_context)
2728
#lls2 = pointwise_loglikelihoods(m, vi)
28-
loglikelihood = sum(sum, values(lls))
29-
if loglikelihood 0.0 #isempty(lls)
29+
loglikelihood_sum = sum(sum, values(lls))
30+
if loglikelihood_sum 0.0 #isempty(lls)
3031
# One of the models with literal observations, so we just skip.
3132
# TODO: Think of better way to detect this special case
3233
loglikelihood_true = 0.0
3334
end
34-
@test loglikelihood loglikelihood_true
35+
@test loglikelihood_sum loglikelihood_true
3536

3637
# Compute the pointwise logdensities of the priors.
3738
lps_prior = pointwise_logdensities(m, vi, prior_context)

0 commit comments

Comments
 (0)