@@ -29,6 +29,42 @@ function Base.push!(
29
29
ctx. loglikelihoods[vn] = logp
30
30
end
31
31
32
+ function Base. push! (
33
+ ctx:: ElementwiseLikelihoodContext{Dict{String, Vector{Float64}}} ,
34
+ vn:: VarName ,
35
+ logp:: Real
36
+ )
37
+ lookup = ctx. loglikelihoods
38
+ ℓ = get! (lookup, string (vn), Float64[])
39
+ push! (ℓ, logp)
40
+ end
41
+
42
+ function Base. push! (
43
+ ctx:: ElementwiseLikelihoodContext{Dict{String, Float64}} ,
44
+ vn:: VarName ,
45
+ logp:: Real
46
+ )
47
+ ctx. loglikelihoods[string (vn)] = logp
48
+ end
49
+
50
+ function Base. push! (
51
+ ctx:: ElementwiseLikelihoodContext{Dict{String, Vector{Float64}}} ,
52
+ vn:: String ,
53
+ logp:: Real
54
+ )
55
+ lookup = ctx. loglikelihoods
56
+ ℓ = get! (lookup, vn, Float64[])
57
+ push! (ℓ, logp)
58
+ end
59
+
60
+ function Base. push! (
61
+ ctx:: ElementwiseLikelihoodContext{Dict{String, Float64}} ,
62
+ vn:: String ,
63
+ logp:: Real
64
+ )
65
+ ctx. loglikelihoods[vn] = logp
66
+ end
67
+
32
68
33
69
function tilde_assume (rng, ctx:: ElementwiseLikelihoodContext , sampler, right, vn, inds, vi)
34
70
return tilde_assume (rng, ctx. ctx, sampler, right, vn, inds, vi)
56
92
57
93
58
94
"""
59
- elementwise_loglikelihoods(model::Model, chain::Chains)
95
+ elementwise_loglikelihoods(model::Model, chain::Chains, keytype = String )
60
96
61
- Runs `model` on each sample in `chain` returning an array of arrays with
62
- the i-th element inner arrays corresponding to the the likelihood of the i-th
63
- observation for that particular sample in `chain`.
97
+ Runs `model` on each sample in `chain` returning a `Dict{String, Matrix{Float64}}`
98
+ with keys corresponding to symbols of the observations, and values being matrices
99
+ of shape `(num_chains, num_samples)`.
100
+
101
+ `keytype` specifies what the type of the keys used in the returned `Dict` are.
102
+ Currently, only `String` and `VarName` are supported.
64
103
65
104
# Notes
66
105
Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
75
114
```julia
76
115
y ~ MvNormal(fill(μ, n), fill(σ, n))
77
116
```
78
- Unfortunately, just by looking at the latter statement, it's impossible to tell whether or
79
- not this is one *single* observation which is `n` dimensional OR if we have *multiple*
80
- 1-dimensional observations. Therefore, `loglikelihoods` will only work with the first
81
- example.
117
+ Unfortunately, just by looking at the latter statement, it's impossible to tell
118
+ whether or not this is one *single* observation which is `n` dimensional OR if we
119
+ have *multiple* 1-dimensional observations. Therefore, `loglikelihoods` will only
120
+ work with the first example.
82
121
83
122
# Examples
84
123
```julia-repl
@@ -99,19 +138,37 @@ julia> model = demo(randn(3), randn());
99
138
100
139
julia> chain = sample(model, MH(), 10);
101
140
102
- julia> DynamicPPL.elementwise_loglikelihoods(model, chain)
103
- Dict{String,Array{Float64,1}} with 4 entries:
104
- "xs[3]" => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
105
- "xs[1]" => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
106
- "xs[2]" => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
107
- "y" => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
141
+ julia> elementwise_loglikelihoods(model, chain)
142
+ Dict{String,Array{Float64,2}} with 4 entries:
143
+ "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
144
+ "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
145
+ "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
146
+ "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
147
+
148
+ julia> elementwise_loglikelihoods(model, chain, String)
149
+ Dict{String,Array{Float64,2}} with 4 entries:
150
+ "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
151
+ "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
152
+ "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
153
+ "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
154
+
155
+ julia> elementwise_loglikelihoods(model, chain, VarName)
156
+ Dict{VarName,Array{Float64,2}} with 4 entries:
157
+ xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
158
+ y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
159
+ xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
160
+ xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
108
161
```
109
162
"""
110
- function elementwise_loglikelihoods (model:: Model , chain)
163
+ function elementwise_loglikelihoods (
164
+ model:: Model ,
165
+ chain,
166
+ keytype:: Type{T} = String
167
+ ) where {T}
111
168
# Get the data by executing the model once
112
- ctx = ElementwiseLikelihoodContext ()
113
169
spl = SampleFromPrior ()
114
170
vi = VarInfo (model)
171
+ ctx = ElementwiseLikelihoodContext (Dict {T, Vector{Float64}} ())
115
172
116
173
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
117
174
for (sample_idx, chain_idx) in iters
@@ -121,7 +178,14 @@ function elementwise_loglikelihoods(model::Model, chain)
121
178
# Execute model
122
179
model (vi, spl, ctx)
123
180
end
124
- return ctx. loglikelihoods
181
+
182
+ niters = size (chain, 1 )
183
+ nchains = size (chain, 3 )
184
+ loglikelihoods = Dict (
185
+ varname => reshape (logliks, niters, nchains)
186
+ for (varname, logliks) in ctx. loglikelihoods
187
+ )
188
+ return loglikelihoods
125
189
end
126
190
127
191
function elementwise_loglikelihoods (model:: Model , varinfo:: AbstractVarInfo )
0 commit comments