Skip to content

Commit e8172f0

Browse files
JaimeRZPgithub-actions[bot]devmotiontorfjelde
authored
PriorExtractorContext (#496)
* first commt * export context * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * missing comma * Update src/contexts.jl Co-authored-by: David Widmann <[email protected]> * fixed compilation * extract_priors * tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * tests for dot * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * bug * Apply suggestions from code review * added docstring to extract_priors * fixed and added more tests for extract_priors * moved prior extraction to a separate file * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * forgot to move a small piece of code * added extract_priors to docs * Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added qualifiers to docstring of extract_priors * Revert "added qualifiers to docstring of extract_priors" This reverts commit cab9f9c. * "fixed" the doctests as ran in docs making * make calls to doctest consistent * Update test/runtests.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent e6dd4ef commit e8172f0

File tree

6 files changed

+168
-1
lines changed

6 files changed

+168
-1
lines changed

docs/make.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
using Documenter
22
using DynamicPPL
33
using DynamicPPL: AbstractPPL
4+
# NOTE: This is necessary to ensure that if we print something from
5+
# Distributions.jl in a doctest, then the shown value will not include
6+
# a qualifier; that is, we don't want `Distributions.Normal{Float64}`
7+
# but rather `Normal{Float64}`. The latter is what will then be printed
8+
# in the doctest as run in `test/runtests.jl`, and so we need to stay
9+
# consistent with that.
10+
using Distributions
411

512
# Doctest setup
613
DocMeta.setdocmeta!(

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob
102102
pointwise_loglikelihoods
103103
```
104104

105+
Sometimes it can be useful to extract the priors of a model. This is the possible using [`extract_priors`](@ref).
106+
107+
```@docs
108+
extract_priors
109+
```
110+
105111
```@docs
106112
NamedDist
107113
```

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ export AbstractVarInfo,
8686
getmissings,
8787
getargnames,
8888
generated_quantities,
89+
extract_priors,
8990
# Samplers
9091
Sampler,
9192
SampleFromPrior,
@@ -166,5 +167,6 @@ include("submodel_macro.jl")
166167
include("test_utils.jl")
167168
include("transforming.jl")
168169
include("logdensityfunction.jl")
170+
include("extract_priors.jl")
169171

170172
end # module

src/extract_priors.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <:
2+
AbstractContext
3+
priors::D
4+
context::Ctx
5+
end
6+
7+
PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context)
8+
9+
NodeTrait(::PriorExtractorContext) = IsParent()
10+
childcontext(context::PriorExtractorContext) = context.context
11+
function setchildcontext(parent::PriorExtractorContext, child)
12+
return PriorExtractorContext(parent.priors, child)
13+
end
14+
15+
function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
16+
return context.priors[vn] = dist
17+
end
18+
19+
function setprior!(
20+
context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution
21+
)
22+
for vn in vns
23+
context.priors[vn] = dist
24+
end
25+
end
26+
27+
function setprior!(
28+
context::PriorExtractorContext,
29+
vns::AbstractArray{<:VarName},
30+
dists::AbstractArray{<:Distribution},
31+
)
32+
for (vn, dist) in zip(vns, dists)
33+
context.priors[vn] = dist
34+
end
35+
end
36+
37+
function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi)
38+
setprior!(context, vn, right)
39+
return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
40+
end
41+
42+
function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
43+
setprior!(context, vn, right)
44+
return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi)
45+
end
46+
47+
"""
48+
extract_priors([rng::Random.AbstractRNG, ]model::Model)
49+
50+
Extract the priors from a model.
51+
52+
This is done by sampling from the model and
53+
recording the distributions that are used to generate the samples.
54+
55+
!!! warning
56+
Because the extraction is done by execution of the model, there
57+
are several caveats:
58+
59+
1. If one variable, say, `y ~ Normal(0, x)`, where `x ~ Normal()`
60+
is also a random variable, then the extracted prior will have
61+
different parameters in every extraction!
62+
2. If the model does _not_ have static support, say,
63+
`n ~ Categorical(1:10); x ~ MvNormmal(zeros(n), I)`, then the
64+
extracted priors themselves will be different between extractions,
65+
not just their parameters.
66+
67+
Both of these caveats are demonstrated below.
68+
69+
# Examples
70+
71+
## Changing parameters
72+
73+
```jldoctest
74+
julia> using Distributions, StableRNGs
75+
76+
julia> rng = StableRNG(42);
77+
78+
julia> @model function model_dynamic_parameters()
79+
x ~ Normal(0, 1)
80+
y ~ Normal(x, 1)
81+
end;
82+
83+
julia> model = model_dynamic_parameters();
84+
85+
julia> extract_priors(rng, model)[@varname(y)]
86+
Normal{Float64}(μ=-0.6702516921145671, σ=1.0)
87+
88+
julia> extract_priors(rng, model)[@varname(y)]
89+
Normal{Float64}(μ=1.3736306979834252, σ=1.0)
90+
```
91+
92+
## Changing support
93+
94+
```jldoctest
95+
julia> using LinearAlgebra, Distributions, StableRNGs
96+
97+
julia> rng = StableRNG(42);
98+
99+
julia> @model function model_dynamic_support()
100+
n ~ Categorical(ones(10) ./ 10)
101+
x ~ MvNormal(zeros(n), I)
102+
end;
103+
104+
julia> model = model_dynamic_support();
105+
106+
julia> length(extract_priors(rng, model)[@varname(x)])
107+
6
108+
109+
julia> length(extract_priors(rng, model)[@varname(x)])
110+
9
111+
```
112+
"""
113+
extract_priors(model::Model) = extract_priors(Random.default_rng(), model)
114+
function extract_priors(rng::Random.AbstractRNG, model::Model)
115+
context = PriorExtractorContext(SamplingContext(rng))
116+
evaluate!!(model, VarInfo(), context)
117+
return context.priors
118+
end

test/model.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,19 @@ struct MyZeroModel end
1212
return x ~ Normal(m, 1)
1313
end
1414

15+
innermost_distribution_type(d::Distribution) = typeof(d)
16+
function innermost_distribution_type(d::Distributions.ReshapedDistribution)
17+
return innermost_distribution_type(d.dist)
18+
end
19+
function innermost_distribution_type(d::Distributions.Product)
20+
dists = map(innermost_distribution_type, d.v)
21+
if any(!=(dists[1]), dists)
22+
error("Cannot extract innermost distribution type from $d")
23+
end
24+
25+
return dists[1]
26+
end
27+
1528
@testset "model.jl" begin
1629
@testset "convenience functions" begin
1730
model = gdemo_default
@@ -154,4 +167,22 @@ end
154167
@model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I)
155168
@test length(test_defaults(missing, 2)()) == 2
156169
end
170+
171+
@testset "extract priors" begin
172+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
173+
priors = extract_priors(model)
174+
175+
# We know that any variable starting with `s` should have `InverseGamma`
176+
# and any variable starting with `m` should have `Normal`.
177+
for (vn, prior) in priors
178+
if DynamicPPL.getsym(vn) == :s
179+
@test innermost_distribution_type(prior) <: InverseGamma
180+
elseif DynamicPPL.getsym(vn) == :m
181+
@test innermost_distribution_type(prior) <: Union{Normal,MvNormal}
182+
else
183+
error("Unexpected variable name: $vn")
184+
end
185+
end
186+
end
187+
end
157188
end

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ include("test_util.jl")
6060

6161
@testset "doctests" begin
6262
DocMeta.setdocmeta!(
63-
DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true
63+
DynamicPPL,
64+
:DocTestSetup,
65+
:(using DynamicPPL, Distributions);
66+
recursive=true,
6467
)
6568
doctestfilters = [
6669
# Older versions will show "0 element Array" instead of "Type[]".

0 commit comments

Comments
 (0)