@@ -40,13 +40,8 @@ function dot_tilde_assume(context::VarwisePriorContext, right, left, vn, vi)
40
40
end
41
41
42
42
43
- function tilde_observe (context:: VarwisePriorContext , right, left, vi)
44
- # Since we are evaluating the prior, the log probability of all the observations
45
- # is set to 0. This has the effect of ignoring the likelihood.
46
- return 0.0 , vi
47
- # tmp = tilde_observe(context.context, SampleFromPrior(), right, left, vi)
48
- # return tmp
49
- end
43
+ tilde_observe (context:: VarwisePriorContext , right, left, vi) = 0 , vi
44
+ dot_tilde_observe (:: VarwisePriorContext , right, left, vi) = 0 , vi
50
45
51
46
function acc_logp! (context:: VarwisePriorContext , vn:: Union{VarName,AbstractVector{<:VarName}} , logp)
52
47
# sym = DynamicPPL.getsym(vn) # leads to duplicates
@@ -56,105 +51,52 @@ function acc_logp!(context::VarwisePriorContext, vn::Union{VarName,AbstractVecto
56
51
return (context)
57
52
end
58
53
59
-
60
- # """
61
- # pointwise_logpriors(model::Model, chain::Chains, keytype = String)
62
-
63
- # Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
64
- # with keys corresponding to symbols of the observations, and values being matrices
65
- # of shape `(num_chains, num_samples)`.
66
-
67
- # `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
68
- # Currently, only `String` and `VarName` are supported.
69
-
70
- # # Notes
71
- # Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
72
- # both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
73
- # *observation*) statements can be implemented in three ways:
74
- # 1. using a `for` loop:
75
- # ```julia
76
- # for i in eachindex(y)
77
- # y[i] ~ Normal(μ, σ)
78
- # end
79
- # ```
80
- # 2. using `.~`:
81
- # ```julia
82
- # y .~ Normal(μ, σ)
83
- # ```
84
- # 3. using `MvNormal`:
85
- # ```julia
86
- # y ~ MvNormal(fill(μ, n), σ^2 * I)
87
- # ```
88
-
89
- # In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
90
- # while in (3) `y` will be treated as a _single_ n-dimensional observation.
91
-
92
- # This is important to keep in mind, in particular if the computation is used
93
- # for downstream computations.
94
-
95
- # # Examples
96
- # ## From chain
97
- # ```julia-repl
98
- # julia> using DynamicPPL, Turing
99
-
100
- # julia> @model function demo(xs, y)
101
- # s ~ InverseGamma(2, 3)
102
- # m ~ Normal(0, √s)
103
- # for i in eachindex(xs)
104
- # xs[i] ~ Normal(m, √s)
105
- # end
106
-
107
- # y ~ Normal(m, √s)
108
- # end
109
- # demo (generic function with 1 method)
110
-
111
- # julia> model = demo(randn(3), randn());
112
-
113
- # julia> chain = sample(model, MH(), 10);
114
-
115
- # julia> pointwise_logpriors(model, chain)
116
- # OrderedDict{String,Array{Float64,2}} with 4 entries:
117
- # "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
118
- # "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
119
- # "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
120
- # "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
121
-
122
- # julia> pointwise_logpriors(model, chain, String)
123
- # OrderedDict{String,Array{Float64,2}} with 4 entries:
124
- # "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
125
- # "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
126
- # "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
127
- # "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
128
-
129
- # julia> pointwise_logpriors(model, chain, VarName)
130
- # OrderedDict{VarName,Array{Float64,2}} with 4 entries:
131
- # xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
132
- # xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
133
- # xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
134
- # y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
135
- # ```
136
-
137
- # ## Broadcasting
138
- # Note that `x .~ Dist()` will treat `x` as a collection of
139
- # _independent_ observations rather than as a single observation.
140
-
141
- # ```jldoctest; setup = :(using Distributions)
142
- # julia> @model function demo(x)
143
- # x .~ Normal()
144
- # end;
145
-
146
- # julia> m = demo([1.0, ]);
147
-
148
- # julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first(ℓ[@varname(x[1])])
149
- # -1.4189385332046727
150
-
151
- # julia> m = demo([1.0; 1.0]);
152
-
153
- # julia> ℓ = pointwise_logpriors(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
154
- # (-1.4189385332046727, -1.4189385332046727)
155
- # ```
156
-
157
- # """
54
+ """
55
+ varwise_logpriors(model::Model, chain::Chains; context)
56
+
57
+ Runs `model` on each sample in `chain` returning a tuple `(values, var_names)`
58
+ with var_names corresponding to symbols of the prior components, and values being
59
+ array of shape `(num_samples, num_components, num_chains)`.
60
+
61
+ `context` specifies child context that handles computation of log-priors.
62
+
63
+ # Example
64
+ ```julia; setup = :(using Distributions)
65
+ using DynamicPPL, Turing
66
+
67
+ @model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
68
+ s ~ InverseGamma(2, 3)
69
+ m = TV(undef, length(x))
70
+ for i in eachindex(x)
71
+ m[i] ~ Normal(0, √s)
72
+ end
73
+ x ~ MvNormal(m, √s)
74
+ end
75
+
76
+ model = demo(randn(3), randn());
77
+
78
+ chain = sample(model, MH(), 10);
79
+
80
+ lp = varwise_logpriors(model, chain)
81
+ # Can be used to construct a new Chains object
82
+ #lpc = MCMCChains(varwise_logpriors(model, chain)...)
83
+
84
+ # got a logdensity for each parameter prior
85
+ (but fewer if used `.~` assignments, see below)
86
+ lp[2] == names(chain, :parameters)
87
+
88
+ # for each sample in the Chains object
89
+ size(lp[1])[[1,3]] == size(chain)[[1,3]]
90
+ ```
91
+
92
+ # Broadcasting
93
+ Note that `m .~ Dist()` will treat `m` as a collection of
94
+ _independent_ prior rather than as a single prior,
95
+ but `varwise_logpriors` returns the single
96
+ sum of log-likelihood of components of `m` only.
97
+ If one needs the log-density of the components, one needs to rewrite
98
+ the model with an explicit loop.
99
+ """
158
100
function varwise_logpriors (
159
101
model:: Model , varinfo:: AbstractVarInfo ,
160
102
context:: AbstractContext = PriorContext ()
0 commit comments