@@ -40,13 +40,8 @@ function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi)
4040end
4141
4242
43- function tilde_observe(context::VarwisePriorContext, right, left, vi)
44-     # Since we are evaluating the prior, the log probability of all the observations
45-     # is set to 0. This has the effect of ignoring the likelihood.
46-     return 0.0, vi
47-     # tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi)
48-     # return tmp
49- end
43+ tilde_observe(context::VarwisePriorContext, right, left, vi) = 0, vi
44+ dot_tilde_observe(::VarwisePriorContext, right, left, vi) = 0, vi
5045
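The one-line methods above make observation statements contribute nothing under `VarwisePriorContext`, so only the prior terms are accumulated per variable. A minimal self-contained sketch of that idea (plain Distributions.jl, not the DynamicPPL internals; the variable values are made up for illustration):

```julia
using Distributions

# Per-variable log-prior accumulation for s ~ InverseGamma(2, 3) and m ~ Normal(0, √s),
# with an observation x ~ Normal(m, √s) that is deliberately ignored.
s, m, x = 1.2, 0.3, 0.5
logpriors = Dict(
    "s" => logpdf(InverseGamma(2, 3), s),
    "m" => logpdf(Normal(0, sqrt(s)), m),
)
# The observe term logpdf(Normal(m, sqrt(s)), x) is dropped, mirroring
# tilde_observe returning 0 under VarwisePriorContext.
total_logprior = sum(values(logpriors))
```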
5146function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVector{<:VarName}}, logp)
5247 # sym = DynamicPPL.getsym(vn) # leads to duplicates
@@ -56,105 +51,52 @@ function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVecto
5651 return (context)
5752end
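The commented-out `getsym` line above hints at why accumulation is keyed by the full `VarName` rather than its symbol: indexed variables such as `m[1]` and `m[2]` share the symbol `:m`, so keying by symbol would collapse their entries into duplicates. A hypothetical sketch of that distinction (plain Julia, not this PR's implementation):

```julia
using Distributions

# Two indexed variables that share the symbol :m.
draws = [("m[1]", Normal(0, 1), 0.1), ("m[2]", Normal(0, 1), -0.2)]

by_name = Dict{String,Float64}()  # keyed by the full variable name
by_sym  = Dict{Symbol,Float64}()  # keyed by the symbol alone

for (name, dist, val) in draws
    by_name[name] = get(by_name, name, 0.0) + logpdf(dist, val)
    by_sym[:m]    = get(by_sym, :m, 0.0) + logpdf(dist, val)  # both collapse onto :m
end

length(by_name)  # 2 separate per-variable entries
length(by_sym)   # 1 merged entry
```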
5853
59-
60- # """
61- # pointwise_logpriors(model::Model, chain::Chains, keytype = String)
62-
63- # Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
64- # with keys corresponding to symbols of the observations, and values being matrices
65- # of shape `(num_chains, num_samples)`.
66-
67- # `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
68- # Currently, only `String` and `VarName` are supported.
69-
70- # # Notes
71- # Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
72- # both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
73- # *observation*) statements can be implemented in three ways:
74- # 1. using a `for` loop:
75- # ```julia
76- # for i in eachindex(y)
77- # y[i] ~ Normal(μ, σ)
78- # end
79- # ```
80- # 2. using `.~`:
81- # ```julia
82- # y .~ Normal(μ, σ)
83- # ```
84- # 3. using `MvNormal`:
85- # ```julia
86- # y ~ MvNormal(fill(μ, n), σ^2 * I)
87- # ```
88-
89- # In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
90- # while in (3) `y` will be treated as a _single_ n-dimensional observation.
91-
92- # This is important to keep in mind, in particular if the computation is used
93- # for downstream computations.
94-
95- # # Examples
96- # ## From chain
97- # ```julia-repl
98- # julia> using DynamicPPL, Turing
99-
100- # julia> @model function demo(xs, y)
101- # s ~ InverseGamma(2, 3)
102- # m ~ Normal(0, √s)
103- # for i in eachindex(xs)
104- # xs[i] ~ Normal(m, √s)
105- # end
106-
107- # y ~ Normal(m, √s)
108- # end
109- # demo (generic function with 1 method)
110-
111- # julia> model = demo(randn(3), randn());
112-
113- # julia> chain = sample(model, MH(), 10);
114-
115- # julia> pointwise_logpriors(model, chain)
116- # OrderedDict{String,Array{Float64,2}} with 4 entries:
117- # "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
118- # "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
119- # "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
120- # "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
121-
122- # julia> pointwise_logpriors(model, chain, String)
123- # OrderedDict{String,Array{Float64,2}} with 4 entries:
124- # "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
125- # "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
126- # "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
127- # "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
128-
129- # julia> pointwise_logpriors(model, chain, VarName)
130- # OrderedDict{VarName,Array{Float64,2}} with 4 entries:
131- # xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
132- # xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
133- # xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
134- # y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
135- # ```
136-
137- # ## Broadcasting
138- # Note that `x .~ Dist()` will treat `x` as a collection of
139- # _independent_ observations rather than as a single observation.
140-
141- # ```jldoctest; setup = :(using Distributions)
142- # julia> @model function demo(x)
143- # x .~ Normal()
144- # end;
145-
146- # julia> m = demo([1.0, ]);
147-
148- # julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])])
149- # -1.4189385332046727
150-
151- # julia> m = demo([1.0; 1.0]);
152-
153- # julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
154- # (-1.4189385332046727, -1.4189385332046727)
155- # ```
156-
157- # """
54+ """
55+ varwise_logpriors(model::Model, chain::Chains; context)
56+
57+ Run `model` on each sample in `chain`, returning a tuple `(values, var_names)`,
58+ where `var_names` holds the names of the prior components and `values` is an
59+ array of shape `(num_samples, num_components, num_chains)`.
60+
61+ `context` specifies the child context that handles the computation of the log-priors.
62+
63+ # Example
64+ ```julia
65+ using DynamicPPL, Turing
66+
67+ @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
68+ s ~ InverseGamma(2, 3)
69+ m = TV(undef, length(x))
70+ for i in eachindex(x)
71+ m[i] ~ Normal(0, √s)
72+ end
73+ x ~ MvNormal(m, √s)
74+ end
75+
76+ model = demo(randn(3), randn());
77+
78+ chain = sample(model, MH(), 10);
79+
80+ lp = varwise_logpriors(model, chain)
81+ # Can be used to construct a new Chains object
82+ # lpc = MCMCChains.Chains(varwise_logpriors(model, chain)...)
83+
84+ # One log-density per parameter prior
85+ # (fewer if `.~` assignments are used, see below):
86+ lp[2] == names(chain, :parameters)
87+
88+ # Rows and chain slices match the samples and chains of the Chains object:
89+ size(lp[1])[[1, 3]] == size(chain)[[1, 3]]
90+ ```
91+
92+ # Broadcasting
93+ Note that `m .~ Dist()` treats `m` as a collection of
94+ _independent_ priors rather than as a single prior,
95+ but `varwise_logpriors` reports only the single sum of the
96+ log-prior densities of the components of `m`.
97+ If the log-density of each component is needed, rewrite
98+ the model with an explicit loop, as sketched below.
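For illustration (a sketch under the same setup as the example above; the model names are ours, and only the looped form is expected to yield one log-prior entry per component of `m`):

```julia
@model function demo_broadcast(x)
    s ~ InverseGamma(2, 3)
    m = Vector{Float64}(undef, length(x))
    m .~ Normal(0, √s)        # contributes a single summed log-prior entry for m
    x ~ MvNormal(m, √s)
end

@model function demo_loop(x)
    s ~ InverseGamma(2, 3)
    m = Vector{Float64}(undef, length(x))
    for i in eachindex(m)
        m[i] ~ Normal(0, √s)  # contributes one log-prior entry per m[i]
    end
    x ~ MvNormal(m, √s)
end
```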
99+ """
158100function varwise_logpriors(
159101    model::Model, varinfo::AbstractVarInfo,
160102    context::AbstractContext=PriorContext()