@@ -6,49 +6,50 @@ An interface for block sampling in Markov Chain Monte Carlo (MCMC).
6
6
Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems.
7
7
It allows different sampling methods to be applied to different parameters.
8
8
"""
9
- struct Gibbs{NT} <: AbstractMCMC.AbstractSampler
9
+ struct Gibbs{NT<: NamedTuple } <: AbstractMCMC.AbstractSampler
10
10
sampler_map:: NT
11
11
end
12
12
13
- struct GibbsState
13
+ struct GibbsState{TraceNT <: NamedTuple ,StateNT <: NamedTuple ,SizeNT <: NamedTuple }
14
14
"""
15
- `trace` contains the values of the values of _all_ parameters up to the last iteration.
15
+ Contains the values of all parameters up to the last iteration.
16
16
"""
17
- trace:: NamedTuple
17
+ trace:: TraceNT
18
18
19
19
"""
20
- `mcmc_states` maps parameters to their sampler-specific MCMC states.
20
+ Maps parameters to their sampler-specific MCMC states.
21
21
"""
22
- mcmc_states:: NamedTuple
22
+ mcmc_states:: StateNT
23
23
24
24
"""
25
- `variable_sizes` maps parameters to their sizes.
25
+ Maps parameters to their sizes.
26
26
"""
27
- variable_sizes:: NamedTuple
27
+ variable_sizes:: SizeNT
28
28
end
29
29
30
- struct GibbsTransition
30
+ struct GibbsTransition{ValuesNT <: NamedTuple }
31
31
"""
32
32
Realizations of the parameters, this is considered a "sample" in the MCMC chain.
33
33
"""
34
- values:: NamedTuple
34
+ values:: ValuesNT
35
35
end
36
36
37
37
"""
38
- flatten(trace::Union{ NamedTuple,OrderedCollections.OrderedDict} )
38
+ flatten(trace::NamedTuple)
39
39
40
- Flatten all the values in the trace into a single vector.
40
+ Flatten all the values in the trace into a single vector. Variable names information is discarded.
41
41
42
42
# Examples
43
43
44
44
```jldoctest; setup = :(using AbstractMCMC: flatten)
45
- julia> flatten((a=[1,2], b=[3,4,5]))
46
- 5-element Vector{Int64}:
47
- 1
48
- 2
49
- 3
50
- 4
51
- 5
45
+ julia> flatten((a=ones(2), b=ones(2, 2)))
46
+ 6-element Vector{Float64}:
47
+ 1.0
48
+ 1.0
49
+ 1.0
50
+ 1.0
51
+ 1.0
52
+ 1.0
52
53
53
54
```
54
55
"""
@@ -57,7 +58,7 @@ function flatten(trace::NamedTuple)
57
58
end
58
59
59
60
"""
60
- unflatten(vec::AbstractVector, group_names_and_sizes::NamedTuple )
61
+ unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple} )
61
62
62
63
Reverse operation of flatten. Reshape the vector into the original arrays using size information.
63
64
@@ -71,20 +72,19 @@ julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
71
72
(x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0])
72
73
```
73
74
"""
74
- function unflatten (vec:: AbstractVector , variable_sizes:: NamedTuple )
75
+ function unflatten (
76
+ vec:: AbstractVector , variable_names_and_sizes:: NamedTuple{variable_names}
77
+ ) where {variable_names}
75
78
result = Dict {Symbol,Array} ()
76
79
start_idx = 1
77
- for name in keys (variable_sizes)
78
- size = variable_sizes [name]
80
+ for name in variable_names
81
+ size = variable_names_and_sizes [name]
79
82
end_idx = start_idx + prod (size) - 1
80
83
result[name] = reshape (vec[start_idx: end_idx], size... )
81
84
start_idx = end_idx + 1
82
85
end
83
86
84
- # ensure the order of the keys is the same as the one in variable_sizes
85
- return NamedTuple {Tuple(keys(variable_sizes))} ([
86
- result[name] for name in keys (variable_sizes)
87
- ])
87
+ return NamedTuple {variable_names} (Tuple ([result[name] for name in variable_names]))
88
88
end
89
89
90
90
"""
@@ -95,15 +95,14 @@ Update the trace with the values from the MCMC states of the sub-problems.
95
95
function update_trace (trace:: NamedTuple , gibbs_state:: GibbsState )
96
96
for parameter_variable in keys (gibbs_state. mcmc_states)
97
97
sub_state = gibbs_state. mcmc_states[parameter_variable]
98
- trace = merge (
99
- trace,
100
- unflatten (
101
- vec (sub_state),
102
- NamedTuple {(parameter_variable,)} ((
103
- gibbs_state. variable_sizes[parameter_variable],
104
- )),
105
- ),
98
+ sub_state_params = vec (sub_state)
99
+ unflattened_sub_state_params = unflatten (
100
+ sub_state_params,
101
+ NamedTuple {(parameter_variable,)} ((
102
+ gibbs_state. variable_sizes[parameter_variable],
103
+ )),
106
104
)
105
+ trace = merge (trace, unflattened_sub_state_params)
107
106
end
108
107
return trace
109
108
end
0 commit comments