1
1
# Context version
2
- struct ElementwiseLikelihoodContext {A, Ctx} <: AbstractContext
2
+ struct PointwiseLikelihoodContext {A, Ctx} <: AbstractContext
3
3
loglikelihoods:: A
4
4
ctx:: Ctx
5
5
end
6
6
7
- function ElementwiseLikelihoodContext (
7
+ function PointwiseLikelihoodContext (
8
8
likelihoods = Dict {VarName, Vector{Float64}} (),
9
9
ctx:: AbstractContext = LikelihoodContext ()
10
10
)
11
- return ElementwiseLikelihoodContext {typeof(likelihoods),typeof(ctx)} (likelihoods, ctx)
11
+ return PointwiseLikelihoodContext {typeof(likelihoods),typeof(ctx)} (likelihoods, ctx)
12
12
end
13
13
14
14
function Base. push! (
15
- ctx:: ElementwiseLikelihoodContext {Dict{VarName, Vector{Float64}}} ,
15
+ ctx:: PointwiseLikelihoodContext {Dict{VarName, Vector{Float64}}} ,
16
16
vn:: VarName ,
17
17
logp:: Real
18
18
)
@@ -22,15 +22,15 @@ function Base.push!(
22
22
end
23
23
24
24
function Base. push! (
25
- ctx:: ElementwiseLikelihoodContext {Dict{VarName, Float64}} ,
25
+ ctx:: PointwiseLikelihoodContext {Dict{VarName, Float64}} ,
26
26
vn:: VarName ,
27
27
logp:: Real
28
28
)
29
29
ctx. loglikelihoods[vn] = logp
30
30
end
31
31
32
32
function Base. push! (
33
- ctx:: ElementwiseLikelihoodContext {Dict{String, Vector{Float64}}} ,
33
+ ctx:: PointwiseLikelihoodContext {Dict{String, Vector{Float64}}} ,
34
34
vn:: VarName ,
35
35
logp:: Real
36
36
)
@@ -40,15 +40,15 @@ function Base.push!(
40
40
end
41
41
42
42
function Base. push! (
43
- ctx:: ElementwiseLikelihoodContext {Dict{String, Float64}} ,
43
+ ctx:: PointwiseLikelihoodContext {Dict{String, Float64}} ,
44
44
vn:: VarName ,
45
45
logp:: Real
46
46
)
47
47
ctx. loglikelihoods[string (vn)] = logp
48
48
end
49
49
50
50
function Base. push! (
51
- ctx:: ElementwiseLikelihoodContext {Dict{String, Vector{Float64}}} ,
51
+ ctx:: PointwiseLikelihoodContext {Dict{String, Vector{Float64}}} ,
52
52
vn:: String ,
53
53
logp:: Real
54
54
)
@@ -58,26 +58,26 @@ function Base.push!(
58
58
end
59
59
60
60
function Base. push! (
61
- ctx:: ElementwiseLikelihoodContext {Dict{String, Float64}} ,
61
+ ctx:: PointwiseLikelihoodContext {Dict{String, Float64}} ,
62
62
vn:: String ,
63
63
logp:: Real
64
64
)
65
65
ctx. loglikelihoods[vn] = logp
66
66
end
67
67
68
68
69
- function tilde_assume (rng, ctx:: ElementwiseLikelihoodContext , sampler, right, vn, inds, vi)
69
+ function tilde_assume (rng, ctx:: PointwiseLikelihoodContext , sampler, right, vn, inds, vi)
70
70
return tilde_assume (rng, ctx. ctx, sampler, right, vn, inds, vi)
71
71
end
72
72
73
- function dot_tilde_assume (rng, ctx:: ElementwiseLikelihoodContext , sampler, right, left, vn, inds, vi)
73
+ function dot_tilde_assume (rng, ctx:: PointwiseLikelihoodContext , sampler, right, left, vn, inds, vi)
74
74
value, logp = dot_tilde (rng, ctx. ctx, sampler, right, left, vn, inds, vi)
75
75
acclogp! (vi, logp)
76
76
return value
77
77
end
78
78
79
79
80
- function tilde_observe (ctx:: ElementwiseLikelihoodContext , sampler, right, left, vname, vinds, vi)
80
+ function tilde_observe (ctx:: PointwiseLikelihoodContext , sampler, right, left, vname, vinds, vi)
81
81
# This is slightly unfortunate since it is not completely generic...
82
82
# Ideally we would call `tilde_observe` recursively but then we don't get the
83
83
# loglikelihood value.
92
92
93
93
94
94
"""
95
- elementwise_loglikelihoods (model::Model, chain::Chains, keytype = String)
95
+ pointwise_loglikelihoods (model::Model, chain::Chains, keytype = String)
96
96
97
97
Runs `model` on each sample in `chain` returning a `Dict{String, Matrix{Float64}}`
98
98
with keys corresponding to symbols of the observations, and values being matrices
@@ -138,37 +138,37 @@ julia> model = demo(randn(3), randn());
138
138
139
139
julia> chain = sample(model, MH(), 10);
140
140
141
- julia> elementwise_loglikelihoods (model, chain)
141
+ julia> pointwise_loglikelihoods (model, chain)
142
142
Dict{String,Array{Float64,2}} with 4 entries:
143
143
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
144
144
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
145
145
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
146
146
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
147
147
148
- julia> elementwise_loglikelihoods (model, chain, String)
148
+ julia> pointwise_loglikelihoods (model, chain, String)
149
149
Dict{String,Array{Float64,2}} with 4 entries:
150
150
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
151
151
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
152
152
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
153
153
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
154
154
155
- julia> elementwise_loglikelihoods (model, chain, VarName)
155
+ julia> pointwise_loglikelihoods (model, chain, VarName)
156
156
Dict{VarName,Array{Float64,2}} with 4 entries:
157
157
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
158
158
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
159
159
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
160
160
xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
161
161
```
162
162
"""
163
- function elementwise_loglikelihoods (
163
+ function pointwise_loglikelihoods (
164
164
model:: Model ,
165
165
chain,
166
166
keytype:: Type{T} = String
167
167
) where {T}
168
168
# Get the data by executing the model once
169
169
spl = SampleFromPrior ()
170
170
vi = VarInfo (model)
171
- ctx = ElementwiseLikelihoodContext (Dict {T, Vector{Float64}} ())
171
+ ctx = PointwiseLikelihoodContext (Dict {T, Vector{Float64}} ())
172
172
173
173
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
174
174
for (sample_idx, chain_idx) in iters
@@ -188,8 +188,8 @@ function elementwise_loglikelihoods(
188
188
return loglikelihoods
189
189
end
190
190
191
- function elementwise_loglikelihoods (model:: Model , varinfo:: AbstractVarInfo )
192
- ctx = ElementwiseLikelihoodContext (Dict {VarName, Float64} ())
191
+ function pointwise_loglikelihoods (model:: Model , varinfo:: AbstractVarInfo )
192
+ ctx = PointwiseLikelihoodContext (Dict {VarName, Float64} ())
193
193
model (varinfo, SampleFromPrior (), ctx)
194
194
return ctx. loglikelihoods
195
195
end
0 commit comments