1
- struct TrackedValue{T}
2
- value:: T
3
- end
4
-
5
- is_tracked_value (:: TrackedValue ) = true
6
- is_tracked_value (:: Any ) = false
7
-
8
- check_tilde_rhs (x:: TrackedValue ) = x
9
-
10
1
"""
11
- ValuesAsInModelContext
2
+ ValuesAsInModelAccumulator <: AbstractAccumulator
12
3
13
- A context that is used by [`values_as_in_model`](@ref) to obtain values
4
+ An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
14
5
of the model parameters as they are in the model.
15
6
16
7
This is particularly useful when working in unconstrained space, but one
@@ -19,72 +10,47 @@ wants to extract the realization of a model in a constrained space.
19
10
# Fields
20
11
$(TYPEDFIELDS)
21
12
"""
22
- struct ValuesAsInModelContext{C <: AbstractContext } <: AbstractContext
13
+ struct ValuesAsInModelAccumulator <: AbstractAccumulator
23
14
" values that are extracted from the model"
24
15
values:: OrderedDict
25
16
" whether to extract variables on the LHS of :="
26
17
include_colon_eq:: Bool
27
- " child context"
28
- context:: C
29
18
end
30
- function ValuesAsInModelContext (include_colon_eq, context :: AbstractContext )
31
- return ValuesAsInModelContext (OrderedDict (), include_colon_eq, context )
19
+ function ValuesAsInModelAccumulator (include_colon_eq)
20
+ return ValuesAsInModelAccumulator (OrderedDict (), include_colon_eq)
32
21
end
33
22
34
- NodeTrait (:: ValuesAsInModelContext ) = IsParent ()
35
- childcontext (context:: ValuesAsInModelContext ) = context. context
36
- function setchildcontext (context:: ValuesAsInModelContext , child)
37
- return ValuesAsInModelContext (context. values, context. include_colon_eq, child)
38
- end
23
+ accumulator_name (:: Type{<:ValuesAsInModelAccumulator} ) = :ValuesAsInModel
39
24
40
- is_extracting_values (context:: ValuesAsInModelContext ) = context. include_colon_eq
41
- function is_extracting_values (context:: AbstractContext )
42
- return is_extracting_values (NodeTrait (context), context)
25
+ function split (acc:: ValuesAsInModelAccumulator )
26
+ return ValuesAsInModelAccumulator (empty (acc. values), acc. include_colon_eq)
43
27
end
44
- is_extracting_values (:: IsParent , :: AbstractContext ) = false
45
- is_extracting_values (:: IsLeaf , :: AbstractContext ) = false
46
-
47
- function Base. push! (context:: ValuesAsInModelContext , vn:: VarName , value)
48
- return setindex! (context. values, copy (value), prefix (context, vn))
28
+ function combine (acc1:: ValuesAsInModelAccumulator , acc2:: ValuesAsInModelAccumulator )
29
+ if acc1. include_colon_eq != acc2. include_colon_eq
30
+ msg = " Cannot combine accumulators with different include_colon_eq values."
31
+ throw (ArgumentError (msg))
32
+ end
33
+ return ValuesAsInModelAccumulator (
34
+ merge (acc1. values, acc2. values), acc1. include_colon_eq
35
+ )
49
36
end
50
37
51
- function broadcast_push! (context:: ValuesAsInModelContext , vns, values)
52
- return push! .((context,), vns, values)
38
+ function Base. push! (acc:: ValuesAsInModelAccumulator , vn:: VarName , val)
39
+ setindex! (acc. values, deepcopy (val), vn)
40
+ return acc
53
41
end
54
42
55
- # This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
56
- function broadcast_push! (
57
- context:: ValuesAsInModelContext , vns:: AbstractVector , values:: AbstractMatrix
58
- )
59
- for (vn, col) in zip (vns, eachcol (values))
60
- push! (context, vn, col)
61
- end
43
+ function is_extracting_values (vi:: AbstractVarInfo )
44
+ return hasacc (vi, Val (:ValuesAsInModel )) &&
45
+ getacc (vi, Val (:ValuesAsInModel )). include_colon_eq
62
46
end
63
47
64
- # `tilde_asssume`
65
- function tilde_assume (context:: ValuesAsInModelContext , right, vn, vi)
66
- if is_tracked_value (right)
67
- value = right. value
68
- else
69
- value, vi = tilde_assume (childcontext (context), right, vn, vi)
70
- end
71
- push! (context, vn, value)
72
- return value, vi
73
- end
74
- function tilde_assume (
75
- rng:: Random.AbstractRNG , context:: ValuesAsInModelContext , sampler, right, vn, vi
76
- )
77
- if is_tracked_value (right)
78
- value = right. value
79
- else
80
- value, vi = tilde_assume (rng, childcontext (context), sampler, right, vn, vi)
81
- end
82
- # Save the value.
83
- push! (context, vn, value)
84
- # Pass on.
85
- return value, vi
48
+ function accumulate_assume!! (acc:: ValuesAsInModelAccumulator , val, logjac, vn, right)
49
+ return push! (acc, vn, val)
86
50
end
87
51
52
+ accumulate_observe!! (acc:: ValuesAsInModelAccumulator , right, left, vn) = acc
53
+
88
54
"""
89
55
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
90
56
@@ -103,7 +69,7 @@ space at the cost of additional model evaluations.
103
69
- `model::Model`: model to extract realizations from.
104
70
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
105
71
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
106
- - `context::AbstractContext`: base context to use for the extraction. Defaults
72
+ - `context::AbstractContext`: evaluation context to use in the extraction. Defaults
107
73
to `DynamicPPL.DefaultContext()`.
108
74
109
75
# Examples
@@ -164,7 +130,8 @@ function values_as_in_model(
164
130
varinfo:: AbstractVarInfo ,
165
131
context:: AbstractContext = DefaultContext (),
166
132
)
167
- context = ValuesAsInModelContext (include_colon_eq, context)
168
- evaluate!! (model, varinfo, context)
169
- return context. values
133
+ accs = getaccs (varinfo)
134
+ varinfo = setaccs!! (deepcopy (varinfo), (ValuesAsInModelAccumulator (include_colon_eq),))
135
+ varinfo = last (evaluate!! (model, varinfo, context))
136
+ return getacc (varinfo, Val (:ValuesAsInModel )). values
170
137
end
0 commit comments