Skip to content

Commit ba206f4

Browse files
torfjeldeyebaigithub-actions[bot]
authored
Added functionality for extracting parameter values for a model from chain (#481)
* added methods for extracting parameter values for a model from a given chain * added MCMCchains as a dep to docs * attempt at fixing doctests * remove the doctest as it's not working for some reason * added docs * Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed incorrect function call as pointed out by @YongchaoHuang * moved out argument to the end up of the function signature --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e241dca commit ba206f4

File tree

4 files changed

+234
-1
lines changed

4 files changed

+234
-1
lines changed

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,13 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob
130130
pointwise_loglikelihoods
131131
```
132132

133+
For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref).
134+
135+
```@docs
136+
value_iterator_from_chain
137+
138+
```
139+
133140
Sometimes it can be useful to extract the priors of a model. This is the possible using [`extract_priors`](@ref).
134141

135142
```@docs

src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ export AbstractVarInfo,
123123
unfix,
124124
# Convenience macros
125125
@addlogprob!,
126-
@submodel
126+
@submodel,
127+
value_iterator_from_chain
127128

128129
# Reexport
129130
using Distributions: loglikelihood
@@ -169,6 +170,7 @@ include("submodel_macro.jl")
169170
include("test_utils.jl")
170171
include("transforming.jl")
171172
include("logdensityfunction.jl")
173+
include("model_utils.jl")
172174
include("extract_priors.jl")
173175

174176
end # module

src/model_utils.jl

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
varnames_in_chain(model:::Model, chain)
3+
varnames_in_chain(varinfo::VarInfo, chain)
4+
5+
Return `true` if all variable names in `model`/`varinfo` are in `chain`.
6+
"""
7+
varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain)
8+
function varnames_in_chain(varinfo::VarInfo, chain)
9+
return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo))
10+
end
11+
12+
"""
13+
varnames_in_chain!(model::Model, chain, out)
14+
varnames_in_chain!(varinfo::VarInfo, chain, out)
15+
16+
Return `out` with `true` for all variable names in `model` that are in `chain`.
17+
"""
18+
function varnames_in_chain!(model::Model, chain, out)
19+
return varnames_in_chain!(VarInfo(model), chain, out)
20+
end
21+
function varnames_in_chain!(varinfo::VarInfo, chain, out)
22+
for vn in keys(varinfo)
23+
varname_in_chain!(varinfo, vn, chain, 1, 1, out)
24+
end
25+
26+
return out
27+
end
28+
29+
"""
30+
varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx)
31+
varname_in_chain(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx)
32+
33+
Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`.
34+
"""
35+
function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx)
36+
return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx)
37+
end
38+
39+
function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx)
40+
!haskey(varinfo, vn) && return false
41+
return varname_in_chain(varinfo[vn], vn, chain, chain_idx, iteration_idx)
42+
end
43+
44+
function varname_in_chain(x, vn, chain, chain_idx, iteration_idx)
45+
out = OrderedDict{VarName,Bool}()
46+
varname_in_chain!(x, vn, chain, chain_idx, iteration_idx, out)
47+
return all(values(out))
48+
end
49+
50+
"""
51+
varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out)
52+
varname_in_chain!(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx, out)
53+
54+
Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at
55+
`chain_idx` and `iteration_idx`.
56+
57+
If `chain_idx` and `iteration_idx` are not provided, then they default to `1`.
58+
59+
This differs from [`varname_in_chain`](@ref) in that it returns a dictionary
60+
rather than a single boolean. This can be quite useful for debugging purposes.
61+
"""
62+
function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out)
63+
return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out)
64+
end
65+
66+
function varname_in_chain!(
67+
vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx, out
68+
)
69+
return varname_in_chain!(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx, out)
70+
end
71+
72+
function varname_in_chain!(x, vn_parent, chain, chain_idx, iteration_idx, out)
73+
sym = Symbol(vn_parent)
74+
out[vn_parent] = sym names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx])
75+
return out
76+
end
77+
78+
function varname_in_chain!(
79+
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out
80+
) where {sym}
81+
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
82+
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
83+
# to extract the value from the `chain`.
84+
for vn in varname_leaves(VarName{sym}(), x)
85+
# Update `out`, possibly in place, and return.
86+
l = AbstractPPL.getlens(vn)
87+
varname_in_chain!(x, vn_parent l, chain, chain_idx, iteration_idx, out)
88+
end
89+
return out
90+
end
91+
92+
"""
93+
values_from_chain(model::Model, chain, chain_idx, iteration_idx)
94+
values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx)
95+
96+
Return a dictionary mapping each variable name in `model`/`varinfo` to its
97+
value in `chain` at `chain_idx` and `iteration_idx`.
98+
"""
99+
function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx)
100+
# HACK: If it's not an array, we fall back to just returning the first value.
101+
return only(chain[iteration_idx, Symbol(vn_parent), chain_idx])
102+
end
103+
function values_from_chain(
104+
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx
105+
) where {sym}
106+
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
107+
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
108+
# to extract the value from the `chain`.
109+
out = similar(x)
110+
for vn in varname_leaves(VarName{sym}(), x)
111+
# Update `out`, possibly in place, and return.
112+
l = AbstractPPL.getlens(vn)
113+
out = Setfield.set(
114+
out,
115+
BangBang.prefermutation(l),
116+
chain[iteration_idx, Symbol(vn_parent l), chain_idx],
117+
)
118+
end
119+
120+
return out
121+
end
122+
function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx)
123+
# Use the value `vi[vn_parent]` to obtain a buffer.
124+
return values_from_chain(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx)
125+
end
126+
127+
"""
128+
values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out)
129+
values_from_chain!(varinfo::VarInfo, chain, chain_idx, iteration_idx, out)
130+
131+
Mutate `out` to map each variable name in `model`/`varinfo` to its value in
132+
`chain` at `chain_idx` and `iteration_idx`.
133+
"""
134+
function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out)
135+
return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out)
136+
end
137+
138+
function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out)
139+
for vn in keys(vi)
140+
out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx)
141+
end
142+
return out
143+
end
144+
145+
"""
146+
value_iterator_from_chain(model::Model, chain)
147+
value_iterator_from_chain(varinfo::AbstractVarInfo, chain)
148+
149+
Return an iterator over the values in `chain` for each variable in `model`/`varinfo`.
150+
151+
# Example
152+
```julia
153+
julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs
154+
155+
julia> rng = StableRNG(42);
156+
157+
julia> @model function demo_model(x)
158+
s ~ InverseGamma(2, 3)
159+
m ~ Normal(0, sqrt(s))
160+
for i in eachindex(x)
161+
x[i] ~ Normal(m, sqrt(s))
162+
end
163+
164+
return s, m
165+
end
166+
demo_model (generic function with 2 methods)
167+
168+
julia> model = demo_model([1.0, 2.0]);
169+
170+
julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]);
171+
172+
julia> iter = value_iterator_from_chain(model, chain);
173+
174+
julia> first(iter)
175+
OrderedDict{VarName, Any} with 2 entries:
176+
s => 0.580515
177+
m => 0.739328
178+
179+
julia> collect(iter)
180+
10×3 Matrix{OrderedDict{VarName, Any}}:
181+
OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423)
182+
OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342)
183+
OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044)
184+
OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502)
185+
OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399)
186+
OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802)
187+
OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922)
188+
OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985)
189+
OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502)
190+
OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288)
191+
192+
julia> # This can be used to `condition` a `Model`.
193+
conditioned_model = model | first(iter);
194+
195+
julia> conditioned_model() # <= results in same values as the `first(iter)` above
196+
(0.5805148626851955, 0.7393275279160691)
197+
```
198+
"""
199+
function value_iterator_from_chain(model::DynamicPPL.Model, chain)
200+
return value_iterator_from_chain(VarInfo(model), chain)
201+
end
202+
203+
function value_iterator_from_chain(vi::AbstractVarInfo, chain)
204+
return Iterators.map(
205+
Iterators.product(1:size(chain, 1), 1:size(chain, 3))
206+
) do (iteration_idx, chain_idx)
207+
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
208+
end
209+
end

test/turing/model.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,19 @@
99
test_setval!(model, MCMCChains.get_sections(chain, :parameters))
1010
end
1111
end
12+
13+
@testset "value_iterator_from_chain" begin
14+
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
15+
chain = sample(model, Prior(), 10; progress=false)
16+
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
17+
for vn in keys(d)
18+
val = DynamicPPL.getvalue(d, vn)
19+
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
20+
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
21+
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
22+
end
23+
end
24+
end
25+
end
26+
end
1227
end

0 commit comments

Comments
 (0)