Skip to content

Commit 161f820

Browse files
committed
generated_quantities and fix for #167 (#168)
This PR adds `generated_quantities` as discussed in TuringLang/Turing.jl#1335 + adds a fix for #167.
1 parent 4b6d95d commit 161f820

File tree

5 files changed

+199
-3
lines changed

5 files changed

+199
-3
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export AbstractVarInfo,
6666
Model,
6767
getmissings,
6868
getargnames,
69+
generated_quantities,
6970
# Samplers
7071
Sampler,
7172
SampleFromPrior,

src/model.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,71 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
200200
model(varinfo, SampleFromPrior(), LikelihoodContext())
201201
return getlogp(varinfo)
202202
end
203+
204+
"""
205+
generated_quantities(model::Model, chain::AbstractChains)
206+
207+
Execute `model` for each of the samples in `chain` and return an array of the values
208+
returned by the `model` for each sample.
209+
210+
# Examples
211+
## General
212+
Often you might have additional quantities computed inside the model that you want to
213+
inspect, e.g.
214+
```julia
215+
@model function demo(x)
216+
# sample and observe
217+
θ ~ Prior()
218+
x ~ Likelihood()
219+
return interesting_quantity(θ, x)
220+
end
221+
m = demo(data)
222+
chain = sample(m, alg, n)
223+
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
224+
# from the posterior/`chain`:
225+
generated_quantities(m, chain) # <= results in a `Vector` of returned values
226+
# from `interesting_quantity(θ, x)`
227+
```
228+
## Concrete (and simple)
229+
```julia
230+
julia> using DynamicPPL, Turing
231+
232+
julia> @model function demo(xs)
233+
s ~ InverseGamma(2, 3)
234+
m_shifted ~ Normal(10, √s)
235+
m = m_shifted - 10
236+
237+
for i in eachindex(xs)
238+
xs[i] ~ Normal(m, √s)
239+
end
240+
241+
return (m, )
242+
end
243+
demo (generic function with 1 method)
244+
245+
julia> model = demo(randn(10));
246+
247+
julia> chain = sample(model, MH(), 10);
248+
249+
julia> generated_quantities(model, chain)
250+
10×1 Array{Tuple{Float64},2}:
251+
(2.1964758025119338,)
252+
(2.1964758025119338,)
253+
(0.09270081916291417,)
254+
(0.09270081916291417,)
255+
(0.09270081916291417,)
256+
(0.09270081916291417,)
257+
(0.09270081916291417,)
258+
(0.043088571494005024,)
259+
(-0.16489786710222099,)
260+
(-0.16489786710222099,)
261+
```
262+
"""
263+
function generated_quantities(model::Model, chain::AbstractChains)
264+
varinfo = VarInfo(model)
265+
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
266+
return map(iters) do (sample_idx, chain_idx)
267+
setval!(varinfo, chain, sample_idx, chain_idx)
268+
model(varinfo)
269+
end
270+
end

src/varinfo.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,9 +1178,12 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
11781178
end
11791179

11801180
function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
1181-
sym = Symbol(vn)
1182-
regex = Regex("^$sym\$|^$sym\\[")
1183-
indices = findall(x -> match(regex, string(x)) !== nothing, keys)
1181+
string_vn = string(vn)
1182+
string_vn_indexing = string_vn * "["
1183+
indices = findall(keys) do x
1184+
string_x = string(x)
1185+
return string_x == string_vn || startswith(string_x, string_vn_indexing)
1186+
end
11841187
if !isempty(indices)
11851188
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
11861189
val = mapreduce(vcat, sorted_indices) do i

test/model.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,98 @@ Random.seed!(1234)
4444
end
4545
end
4646
end
47+
48+
@testset "setval! & generated_quantities" begin
49+
@model function demo1(xs, ::Type{TV} = Vector{Float64}) where {TV}
50+
m = TV(undef, 2)
51+
for i in 1:2
52+
m[i] ~ Normal(0, 1)
53+
end
54+
55+
for i in eachindex(xs)
56+
xs[i] ~ Normal(m[1], 1.)
57+
end
58+
59+
return (m, )
60+
end
61+
62+
@model function demo2(xs)
63+
m ~ MvNormal(2, 1.)
64+
65+
for i in eachindex(xs)
66+
xs[i] ~ Normal(m[1], 1.)
67+
end
68+
69+
return (m, )
70+
end
71+
72+
xs = randn(3);
73+
model1 = demo1(xs);
74+
model2 = demo2(xs);
75+
76+
chain1 = sample(model1, MH(), 100);
77+
chain2 = sample(model2, MH(), 100);
78+
79+
res11 = generated_quantities(model1, chain1)
80+
res21 = generated_quantities(model2, chain1)
81+
82+
res12 = generated_quantities(model1, chain2)
83+
res22 = generated_quantities(model2, chain2)
84+
85+
# Check that the two different models produce the same values for
86+
# the same chains.
87+
@test all(res11 .== res21)
88+
@test all(res12 .== res22)
89+
# Ensure that they're not all the same (some can be, because rejected samples)
90+
@test any(res12[1:end - 1] .!= res12[2:end])
91+
92+
test_setval!(model1, chain1)
93+
test_setval!(model2, chain2)
94+
95+
# Next level
96+
@model function demo3(xs, ::Type{TV} = Vector{Float64}) where {TV}
97+
m = Vector{TV}(undef, 2)
98+
for i = 1:length(m)
99+
m[i] ~ MvNormal(2, 1.)
100+
end
101+
102+
for i in eachindex(xs)
103+
xs[i] ~ Normal(m[1][1], 1.)
104+
end
105+
106+
return (m, )
107+
end
108+
109+
@model function demo4(xs, ::Type{TV} = Vector{Vector{Float64}}) where {TV}
110+
m = TV(undef, 2)
111+
for i = 1:length(m)
112+
m[i] ~ MvNormal(2, 1.)
113+
end
114+
115+
for i in eachindex(xs)
116+
xs[i] ~ Normal(m[1][1], 1.)
117+
end
118+
119+
return (m, )
120+
end
121+
122+
model3 = demo3(xs);
123+
model4 = demo4(xs);
124+
125+
chain3 = sample(model3, MH(), 100);
126+
chain4 = sample(model4, MH(), 100);
127+
128+
res33 = generated_quantities(model3, chain3)
129+
res43 = generated_quantities(model4, chain3)
130+
131+
res34 = generated_quantities(model3, chain4)
132+
res44 = generated_quantities(model4, chain4)
133+
134+
# Check that the two different models produce the same values for
135+
# the same chains.
136+
@test all(res33 .== res43)
137+
@test all(res34 .== res44)
138+
# Ensure that they're not all the same (some can be, because rejected samples)
139+
@test any(res34[1:end - 1] .!= res34[2:end])
140+
end
47141
end

test/test_util.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,33 @@ function test_model_ad(model, logp_manual)
3636
@test y lp
3737
@test back(1)[1] grad
3838
end
39+
40+
41+
"""
42+
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
43+
44+
Test `setval!` on `model` and `chain`.
45+
46+
Worth noting that this only supports models containing symbols of the forms
47+
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
48+
"""
49+
function test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
50+
var_info = VarInfo(model)
51+
spl = SampleFromPrior()
52+
θ_old = var_info[spl]
53+
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
54+
θ_new = var_info[spl]
55+
@test θ_old != θ_new
56+
nt = DynamicPPL.tonamedtuple(var_info)
57+
for (k, (vals, names)) in pairs(nt)
58+
for (n, v) in zip(names, vals)
59+
chain_val = if Symbol(n) keys(chain)
60+
# Assume it's a group
61+
vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx])
62+
else
63+
chain[sample_idx, n, chain_idx]
64+
end
65+
@test v == chain_val
66+
end
67+
end
68+
end

0 commit comments

Comments
 (0)