-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathvalues_as_in_model.jl
More file actions
147 lines (115 loc) · 5.58 KB
/
values_as_in_model.jl
File metadata and controls
147 lines (115 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
    ValuesAsInModelAccumulator <: AbstractAccumulator

Accumulator used by [`values_as_in_model`](@ref) to record the model parameters with
the values they take inside the model body.

This is most useful when the model is being evaluated in unconstrained space but the
realization in the original, constrained space is what one wants to extract.

# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelAccumulator{VNT<:VarNamedTuple} <: AbstractAccumulator
    "values recorded from the model, keyed by variable name"
    values::VNT
    "whether variables on the LHS of `:=` are recorded as well"
    include_colon_eq::Bool
end
# Convenience constructor: begin with an empty `VarNamedTuple` of recorded values.
ValuesAsInModelAccumulator(include_colon_eq) =
    ValuesAsInModelAccumulator(VarNamedTuple(), include_colon_eq)
# Two accumulators are equal when they share the `:=` setting and hold equal values.
# The cheap Bool comparison is done first so the `values` comparison can be skipped.
function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
    same_setting = acc1.include_colon_eq == acc2.include_colon_eq
    same_setting || return false
    return acc1.values == acc2.values
end
# Copy the recorded values container; the Bool flag is carried over as-is.
Base.copy(acc::ValuesAsInModelAccumulator) =
    ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq)
# Key under which this accumulator is stored; matches `Val(:ValuesAsInModel)` lookups.
function accumulator_name(::Type{<:ValuesAsInModelAccumulator})
    return :ValuesAsInModel
end
# TODO(mhauru) We could start using reset!!, which could call empty!! on the VarNamedTuple.
# This would create VarNamedTuples that share memory with the original one, saving
# allocations but also making them not capable of taking in any arbitrary VarName.
# Fresh accumulator with no recorded values but the same `include_colon_eq` setting.
_zero(acc::ValuesAsInModelAccumulator) =
    ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
# Both resetting and splitting yield an empty accumulator with the same `:=` setting.
function reset(acc::ValuesAsInModelAccumulator)
    return _zero(acc)
end
function split(acc::ValuesAsInModelAccumulator)
    return _zero(acc)
end
# Merge the recorded values of two accumulators. Both must agree on `include_colon_eq`,
# otherwise an `ArgumentError` is thrown.
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
    if acc1.include_colon_eq == acc2.include_colon_eq
        merged = merge(acc1.values, acc2.values)
        return ValuesAsInModelAccumulator(merged, acc1.include_colon_eq)
    end
    errmsg = "Cannot combine accumulators with different include_colon_eq values."
    throw(ArgumentError(errmsg))
end
# Record `val` under `vn`. The struct is immutable, so `@reset` rebinds the local `acc`
# to an updated copy, which is then returned.
function BangBang.push!!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
    # TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model
    # body can go mutating the object without that reactively affecting the value in the
    # accumulator, which should be as it was at `~` time. Could there be a way around this?
    snapshot = deepcopy(val)
    Accessors.@reset acc.values = setindex!!(acc.values, snapshot, vn)
    return acc
end
# True only if `vi` holds a `ValuesAsInModel` accumulator whose `include_colon_eq` is set.
function is_extracting_values(vi::AbstractVarInfo)
    hasacc(vi, Val(:ValuesAsInModel)) || return false
    acc = getacc(vi, Val(:ValuesAsInModel))
    return acc.include_colon_eq
end
# On every `~` assumption, store the value as seen by the model; `logjac`/`right` unused.
accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right) =
    push!!(acc, vn, val)
# Observations are not recorded: the accumulator is passed through unchanged.
function accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn)
    return acc
end
"""
    values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo)

Get the values of `varinfo` as they would be seen in the model.

More specifically, this method attempts to extract the realization _as seen in
the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a
realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one
where the value of `x[1]` is positive -- regardless of whether `varinfo` is
working in unconstrained space.

Hence this method is a "safe" way of obtaining realizations in constrained
space at the cost of additional model evaluations.

# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.

# Returns
A `VarNamedTuple` holding the values recorded by a `ValuesAsInModelAccumulator` while
the model was re-evaluated. The `varinfo` passed in is not modified, because the model
is evaluated on a deep copy of it.

# Examples

## When `VarInfo` fails

The following demonstrates a common pitfall when working with [`VarInfo`](@ref)
and constrained variables.

```jldoctest
julia> using Distributions, StableRNGs

julia> rng = StableRNG(42);

julia> @model function model_changing_support()
           x ~ Bernoulli(0.5)
           y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
       end;

julia> model = model_changing_support();

julia> # Construct initial type-stable `VarInfo`.
       varinfo = VarInfo(rng, model);

julia> # Link it so it works in unconstrained space.
       varinfo_linked = DynamicPPL.link(varinfo, model);

julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
       # Flip `x` so we hit the other support of `y`.
       θ = [!varinfo[@varname(x)], rand(rng)];

julia> # Update the `VarInfo` with the new values.
       varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);

julia> # Determine the expected support of `y`.
       lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
(0, 1)

julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
       varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);

julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
       # used in the very first model evaluation, hence the support of `y`
       # is not updated even though `x` has changed.
       lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub
false

julia> # Approach 2: Extract realizations using `values_as_in_model`.
       # (✓) `values_as_in_model` will re-run the model and extract
       # the correct realization of `y` given the new values of `x`.
       lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo)
    # Work on a deep copy so the caller's `varinfo` is left untouched, and give it a
    # single fresh `ValuesAsInModelAccumulator` to record into.
    varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),))
    # Re-run the model; the updated varinfo is the last element `evaluate!!` returns.
    varinfo = last(evaluate!!(model, varinfo))
    return getacc(varinfo, Val(:ValuesAsInModel)).values
end