Skip to content

Commit d11a33f

Browse files
authored
Revert "reverted merge with torfjelde/extract-realizations" (#590)
This reverts commit 33a84c7.
1 parent 8c432b6 commit d11a33f

File tree

4 files changed

+208
-1
lines changed

4 files changed

+208
-1
lines changed

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl
143143
extract_priors
144144
```
145145

146+
Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are seen in the model can be done using [`values_as_in_model`](@ref).
147+
148+
```@docs
149+
values_as_in_model
150+
```
151+
146152
```@docs
147153
NamedDist
148154
```

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ export AbstractVarInfo,
9292
getargnames,
9393
generated_quantities,
9494
extract_priors,
95+
values_as_in_model,
9596
# Samplers
9697
Sampler,
9798
SampleFromPrior,
@@ -179,6 +180,7 @@ include("transforming.jl")
179180
include("logdensityfunction.jl")
180181
include("model_utils.jl")
181182
include("extract_priors.jl")
183+
include("values_as_in_model.jl")
182184

183185
include("debug_utils.jl")
184186
using .DebugUtils

src/values_as_in_model.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
2+
"""
3+
ValuesAsInModelContext
4+
5+
A context that is used by [`values_as_in_model`](@ref) to obtain values
6+
of the model parameters as they are in the model.
7+
8+
This is particularly useful when working in unconstrained space, but one
9+
wants to extract the realization of a model in a constrained space.
10+
11+
# Fields
12+
$(TYPEDFIELDS)
13+
"""
14+
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
15+
"values that are extracted from the model"
16+
values::T
17+
"child context"
18+
context::C
19+
end
20+
21+
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
22+
function ValuesAsInModelContext(context::AbstractContext)
23+
return ValuesAsInModelContext(OrderedDict(), context)
24+
end
25+
26+
NodeTrait(::ValuesAsInModelContext) = IsParent()
27+
childcontext(context::ValuesAsInModelContext) = context.context
28+
function setchildcontext(context::ValuesAsInModelContext, child)
29+
return ValuesAsInModelContext(context.values, child)
30+
end
31+
32+
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
33+
return setindex!(context.values, copy(value), vn)
34+
end
35+
36+
function broadcast_push!(context::ValuesAsInModelContext, vns, values)
37+
return push!.((context,), vns, values)
38+
end
39+
40+
# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
41+
function broadcast_push!(
42+
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
43+
)
44+
for (vn, col) in zip(vns, eachcol(values))
45+
push!(context, vn, col)
46+
end
47+
end
48+
49+
# `tilde_asssume`
50+
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
51+
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
52+
# Save the value.
53+
push!(context, vn, value)
54+
# Save the value.
55+
# Pass on.
56+
return value, logp, vi
57+
end
58+
function tilde_assume(
59+
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
60+
)
61+
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
62+
# Save the value.
63+
push!(context, vn, value)
64+
# Pass on.
65+
return value, logp, vi
66+
end
67+
68+
# `dot_tilde_assume`
69+
function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi)
70+
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)
71+
72+
# Save the value.
73+
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
74+
broadcast_push!(context, _vns, value)
75+
76+
return value, logp, vi
77+
end
78+
function dot_tilde_assume(
79+
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi
80+
)
81+
value, logp, vi = dot_tilde_assume(
82+
rng, childcontext(context), sampler, right, left, vn, vi
83+
)
84+
# Save the value.
85+
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
86+
broadcast_push!(context, _vns, value)
87+
88+
return value, logp, vi
89+
end
90+
91+
"""
92+
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
93+
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
94+
95+
Get the values of `varinfo` as they would be seen in the model.
96+
97+
If no `varinfo` is provided, then this is effectively the same as
98+
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
99+
100+
More specifically, this method attempts to extract the realization _as seen in the model_.
101+
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
102+
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
103+
space.
104+
105+
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
106+
of additional model evaluations.
107+
108+
# Arguments
109+
- `model::Model`: model to extract realizations from.
110+
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
111+
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
112+
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
113+
114+
# Examples
115+
116+
## When `VarInfo` fails
117+
118+
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
119+
120+
```jldoctest
121+
julia> using Distributions, StableRNGs
122+
123+
julia> rng = StableRNG(42);
124+
125+
julia> @model function model_changing_support()
126+
x ~ Bernoulli(0.5)
127+
y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
128+
end;
129+
130+
julia> model = model_changing_support();
131+
132+
julia> # Construct initial type-stable `VarInfo`.
133+
varinfo = VarInfo(rng, model);
134+
135+
julia> # Link it so it works in unconstrained space.
136+
varinfo_linked = DynamicPPL.link(varinfo, model);
137+
138+
julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
139+
# Flip `x` so we hit the other support of `y`.
140+
θ = [!varinfo[@varname(x)], rand(rng)];
141+
142+
julia> # Update the `VarInfo` with the new values.
143+
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
144+
145+
julia> # Determine the expected support of `y`.
146+
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
147+
(0, 1)
148+
149+
julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
150+
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
151+
152+
julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
153+
# used in the very first model evaluation, hence the support of `y`
154+
# is not updated even though `x` has changed.
155+
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
156+
false
157+
158+
julia> # Approach 2: Extract realizations using `values_as_in_model`.
159+
# (✓) `values_as_in_model` will re-run the model and extract
160+
# the correct realization of `y` given the new values of `x`.
161+
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
162+
true
163+
```
164+
"""
165+
function values_as_in_model(
166+
model::Model,
167+
varinfo::AbstractVarInfo=VarInfo(),
168+
context::AbstractContext=DefaultContext(),
169+
)
170+
context = ValuesAsInModelContext(context)
171+
evaluate!!(model, varinfo, context)
172+
return context.values
173+
end
174+
function values_as_in_model(
175+
rng::Random.AbstractRNG,
176+
model::Model,
177+
varinfo::AbstractVarInfo=VarInfo(),
178+
context::AbstractContext=DefaultContext(),
179+
)
180+
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
181+
end

test/model.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
356356
]
357357
@testset "$(model.f)" for model in models_to_test
358358
vns = DynamicPPL.TestUtils.varnames(model)
359-
example_values = DynamicPPL.TestUtils.rand(model)
359+
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
360360
varinfos = filter(
361361
is_typed_varinfo,
362362
DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns),
@@ -375,4 +375,22 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
375375
end
376376
end
377377
end
378+
379+
@testset "values_as_in_model" begin
380+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
381+
vns = DynamicPPL.TestUtils.varnames(model)
382+
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
383+
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
384+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
385+
realizations = values_as_in_model(model, varinfo)
386+
# Ensure that all variables are found.
387+
vns_found = collect(keys(realizations))
388+
@test vns vns_found == vns vns_found
389+
# Ensure that the values are the same.
390+
for vn in vns
391+
@test realizations[vn] == varinfo[vn]
392+
end
393+
end
394+
end
395+
end
378396
end

0 commit comments

Comments
 (0)