
Commit fd8d3b2

docstring varwise_logpriors

Use a loop for the priors in the example. Unfortunately this cannot be made a jldoctest, because it relies on Turing for sampling.

1 parent 216d50c

3 files changed, +77 -122 lines changed

src/logpriors_var.jl

48 additions & 106 deletions

```diff
@@ -40,13 +40,8 @@ function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi)
 end


-function tilde_observe(context::VarwisePriorContext, right, left, vi)
-    # Since we are evaluating the prior, the log probability of all the observations
-    # is set to 0. This has the effect of ignoring the likelihood.
-    return 0.0, vi
-    #tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi)
-    #return tmp
-end
+tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi
+dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi

 function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp)
     #sym = DynamicPPL.getsym(vn) # leads to duplicates
```

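The replaced `tilde_observe`/`dot_tilde_observe` methods return `0, vi`, so observe statements contribute nothing while the prior is being evaluated. A minimal sketch of what "ignoring the likelihood" means in practice, using DynamicPPL's `logprior`, `loglikelihood`, and `logjoint` (the model and values here are illustrative, not part of the commit):

```julia
using DynamicPPL, Distributions

@model function demo_obs(y)
    m ~ Normal(0, 1)      # prior term
    y ~ Normal(m, 1)      # observation (likelihood) term
end

model = demo_obs(2.0)
θ = (m = 0.5,)

lp_prior = logprior(model, θ)       # only the Normal(0, 1) prior on m
lp_like  = loglikelihood(model, θ)  # only the Normal(m, 1) term for y
logjoint(model, θ) ≈ lp_prior + lp_like  # true: dropping observe terms leaves the prior
```
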
```diff
@@ -56,105 +51,52 @@ function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVecto
     return (context)
 end

-
-# """
-#     pointwise_logpriors(model::Model, chain::Chains, keytype = String)
-
-# Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
-# with keys corresponding to symbols of the observations, and values being matrices
-# of shape `(num_chains, num_samples)`.
-
-# `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
-# Currently, only `String` and `VarName` are supported.
-
-# # Notes
-# Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
-# both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
-# *observation*) statements can be implemented in three ways:
-# 1. using a `for` loop:
-# ```julia
-# for i in eachindex(y)
-#     y[i] ~ Normal(μ, σ)
-# end
-# ```
-# 2. using `.~`:
-# ```julia
-# y .~ Normal(μ, σ)
-# ```
-# 3. using `MvNormal`:
-# ```julia
-# y ~ MvNormal(fill(μ, n), σ^2 * I)
-# ```
-
-# In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
-# while in (3) `y` will be treated as a _single_ n-dimensional observation.
-
-# This is important to keep in mind, in particular if the computation is used
-# for downstream computations.
-
-# # Examples
-# ## From chain
-# ```julia-repl
-# julia> using DynamicPPL, Turing
-
-# julia> @model function demo(xs, y)
-#            s ~ InverseGamma(2, 3)
-#            m ~ Normal(0, √s)
-#            for i in eachindex(xs)
-#                xs[i] ~ Normal(m, √s)
-#            end
-
-#            y ~ Normal(m, √s)
-#        end
-# demo (generic function with 1 method)
-
-# julia> model = demo(randn(3), randn());
-
-# julia> chain = sample(model, MH(), 10);
-
-# julia> pointwise_logpriors(model, chain)
-# OrderedDict{String,Array{Float64,2}} with 4 entries:
-#   "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
-#   "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
-#   "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
-#   "y"     => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
-
-# julia> pointwise_logpriors(model, chain, String)
-# OrderedDict{String,Array{Float64,2}} with 4 entries:
-#   "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
-#   "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
-#   "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
-#   "y"     => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
-
-# julia> pointwise_logpriors(model, chain, VarName)
-# OrderedDict{VarName,Array{Float64,2}} with 4 entries:
-#   xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
-#   xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
-#   xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
-#   y     => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
-# ```
-
-# ## Broadcasting
-# Note that `x .~ Dist()` will treat `x` as a collection of
-# _independent_ observations rather than as a single observation.
-
-# ```jldoctest; setup = :(using Distributions)
-# julia> @model function demo(x)
-#            x .~ Normal()
-#        end;
-
-# julia> m = demo([1.0, ]);
-
-# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])])
-# -1.4189385332046727
-
-# julia> m = demo([1.0; 1.0]);
-
-# julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
-# (-1.4189385332046727, -1.4189385332046727)
-# ```
-
-# """
+"""
+    varwise_logpriors(model::Model, chain::Chains; context)
+
+Runs `model` on each sample in `chain`, returning a tuple `(values, var_names)`,
+with `var_names` corresponding to the symbols of the prior components, and `values`
+being an array of shape `(num_samples, num_components, num_chains)`.
+
+`context` specifies the child context that handles the computation of the log-priors.
+
+# Example
+```julia
+using DynamicPPL, Turing
+
+@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
+    s ~ InverseGamma(2, 3)
+    m = TV(undef, length(x))
+    for i in eachindex(x)
+        m[i] ~ Normal(0, √s)
+    end
+    x ~ MvNormal(m, √s)
+end
+
+model = demo(randn(3));
+
+chain = sample(model, MH(), 10);
+
+lp = varwise_logpriors(model, chain)
+# Can be used to construct a new Chains object:
+#lpc = Chains(varwise_logpriors(model, chain)...)
+
+# There is a log-density for each parameter prior
+# (but fewer if `.~` assignments are used, see below):
+lp[2] == names(chain, :parameters)
+
+# and one entry for each sample and chain in the Chains object:
+size(lp[1])[[1,3]] == size(chain)[[1,3]]
+```
+
+# Broadcasting
+Note that `m .~ Dist()` will treat `m` as a collection of
+_independent_ priors rather than as a single prior,
+but `varwise_logpriors` returns only the single
+sum of the log-densities of the components of `m`.
+If one needs the log-density of each component, one needs to rewrite
+the model with an explicit loop.
+"""
 function varwise_logpriors(
     model::Model, varinfo::AbstractVarInfo,
     context::AbstractContext=PriorContext()
```

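For orientation, a usage sketch based on the new docstring: the returned array can be wrapped in an `MCMCChains.Chains` object, as the test below also does (`demo` is the docstring's model; sampling with Turing's `MH` is assumed):

```julia
using DynamicPPL, Turing, MCMCChains

model = demo(randn(3))                # model from the docstring example
chain = sample(model, MH(), 10)

vals, par_names = varwise_logpriors(model, chain)
size(vals)                            # (num_samples, num_components, num_chains)

# per-variable log-prior traces as a Chains object
lp_chain = Chains(vals, par_names)
```
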
src/test_utils.jl

13 additions & 1 deletion

```diff
@@ -1059,7 +1059,10 @@ function TestLogModifyingChildContext(
         mod, context
     )
 end
-DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
+# Samplers call leafcontext(model.context) when evaluating log-densities.
+# Hence, in order to be used, this context needs to declare itself a leaf context.
+#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
+DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf()
 DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
 function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
     return TestLogModifyingChildContext(context.mod, child)
```

```diff
@@ -1074,5 +1077,14 @@ function DynamicPPL.dot_tilde_assume(context::TestLogModifyingChildContext, righ
     value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
     return value, logp*context.mod, vi
 end
+function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
+    value, logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
+    return value, logp*context.mod, vi
+end
+function DynamicPPL.dot_tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
+    value, logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
+    return value, logp*context.mod, vi
+end
+

 end
```

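A small sketch (not part of the commit) of why the `NodeTrait` change matters: `DynamicPPL.leafcontext` keeps unwrapping contexts declared `IsParent` until it reaches an `IsLeaf` context, so a wrapper that should survive until evaluation has to declare itself a leaf. The `DemoWrapper` context below is hypothetical, used only for illustration:

```julia
using DynamicPPL

# Hypothetical wrapper context, declared as a parent node.
struct DemoWrapper{C} <: DynamicPPL.AbstractContext
    context::C
end
DynamicPPL.NodeTrait(::DemoWrapper) = DynamicPPL.IsParent()
DynamicPPL.childcontext(c::DemoWrapper) = c.context
DynamicPPL.setchildcontext(c::DemoWrapper, child) = DemoWrapper(child)

ctx = DemoWrapper(PriorContext())
DynamicPPL.leafcontext(ctx) === PriorContext()   # the wrapper is skipped

# With NodeTrait(...) = IsLeaf() instead, leafcontext(ctx) would return ctx itself,
# which is what TestLogModifyingChildContext needs when a sampler calls
# leafcontext(model.context).
```
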
test/loglikelihoods.jl

16 additions & 15 deletions

```diff
@@ -28,6 +28,7 @@ end
 mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
 #m = DynamicPPL.TestUtils.DEMO_MODELS[1]
 # m = DynamicPPL.TestUtils.demo_assume_index_observe() # logp at i-level?
+# m = DynamicPPL.TestUtils.demo_assume_dot_observe() # failing test?
 @testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS)
     #@show i
     example_values = DynamicPPL.TestUtils.rand_prior_true(m)
```

```diff
@@ -53,7 +54,8 @@ end
     #
     # test on modifying child-context
     logpriors_mod = DynamicPPL.varwise_logpriors(m, vi, mod_ctx2)
-    logp1 = getlogp(vi)
+    logp1 = getlogp(vi)
+    #logp_mod = logprior(m, vi) # uses prior context but not mod_ctx2
     # Following line assumes no Likelihood contributions
     # requires lowest Context to be PriorContext
     @test !isfinite(logp1) || sum(x -> sum(x), values(logpriors_mod)) ≈ logp1 #
```

```diff
@@ -62,25 +64,24 @@ end
 end;

 @testset "logpriors_var chain" begin
-    @model function demo(xs, y)
+    @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
         s ~ InverseGamma(2, 3)
-        m ~ Normal(0, s)
-        for i in eachindex(xs)
-            xs[i] ~ Normal(m, s)
+        m = TV(undef, length(x))
+        for i in eachindex(x)
+            m[i] ~ Normal(0, s)
         end
-        y ~ Normal(m, s)
-    end
-    xs_true, y_true = ([0.3290767977680923, 0.038972110187911684, -0.5797496780649221], -0.7321425592768186)#randn(3), randn()
-    model = demo(xs_true, y_true)
+        x ~ MvNormal(m, s)
+    end
+    x_true = [0.3290767977680923, 0.038972110187911684, -0.5797496780649221]
+    model = demo(x_true)
     () -> begin
         # generate the sample used below
-        chain = sample(model, MH(), 10)
-        arr0 = Array(chain)
+        chain = sample(model, MH(), MCMCThreads(), 10, 2)
+        arr0 = stack(Array(chain, append_chains=false))
+        @show(arr0);
     end
-    arr0 = [1.8585322626573435 -0.05900855284939967; 1.7304068220366808 -0.6386249100228161; 1.7304068220366808 -0.6386249100228161; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039; 0.8732539292509538 -0.004885395480653039]; # generated in function above
-    # split into two chains for testing
-    arr1 = permutedims(reshape(arr0, 5,2,:),(1,3,2))
-    chain = Chains(arr1, [:s, :m]);
+    arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939]
+    chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
     tmp1 = varwise_logpriors(model, chain)
     tmp = Chains(tmp1...); # can be used to create a Chains object
     vi = VarInfo(model)
```

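As a sanity-check sketch beyond the test above (assuming the returned names are the chain's parameter symbols, as the docstring example suggests; `model` and `chain` are the objects built in the testset, and the indices are looked up by name rather than assumed):

```julia
using Distributions

lp_vals, lp_names = varwise_logpriors(model, chain)

# first sample of the first chain
s1 = chain[:s][1, 1]
m1 = chain[Symbol("m[1]")][1, 1]

i_s  = findfirst(==(:s), lp_names)
i_m1 = findfirst(==(Symbol("m[1]")), lp_names)

# each entry should match the log-density of the corresponding prior
lp_vals[1, i_s, 1]  ≈ logpdf(InverseGamma(2, 3), s1)
lp_vals[1, i_m1, 1] ≈ logpdf(Normal(0, s1), m1)
```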