Skip to content

Commit 7f889cf

Browse files
committed
rename gibbs test file to prepare for moving
1 parent 39c4d87 commit 7f889cf

File tree

2 files changed

+34
-35
lines changed

2 files changed

+34
-35
lines changed

src/gibbs.jl

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,50 @@ An interface for block sampling in Markov Chain Monte Carlo (MCMC).
66
Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems.
77
It allows different sampling methods to be applied to different parameters.
88
"""
9-
struct Gibbs{NT} <: AbstractMCMC.AbstractSampler
9+
struct Gibbs{NT<:NamedTuple} <: AbstractMCMC.AbstractSampler
1010
sampler_map::NT
1111
end
1212

13-
struct GibbsState
13+
struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple}
1414
"""
15-
`trace` contains the values of the values of _all_ parameters up to the last iteration.
15+
Contains the values of all parameters up to the last iteration.
1616
"""
17-
trace::NamedTuple
17+
trace::TraceNT
1818

1919
"""
20-
`mcmc_states` maps parameters to their sampler-specific MCMC states.
20+
Maps parameters to their sampler-specific MCMC states.
2121
"""
22-
mcmc_states::NamedTuple
22+
mcmc_states::StateNT
2323

2424
"""
25-
`variable_sizes` maps parameters to their sizes.
25+
Maps parameters to their sizes.
2626
"""
27-
variable_sizes::NamedTuple
27+
variable_sizes::SizeNT
2828
end
2929

30-
struct GibbsTransition
30+
struct GibbsTransition{ValuesNT<:NamedTuple}
3131
"""
3232
Realizations of the parameters, this is considered a "sample" in the MCMC chain.
3333
"""
34-
values::NamedTuple
34+
values::ValuesNT
3535
end
3636

3737
"""
38-
flatten(trace::Union{NamedTuple,OrderedCollections.OrderedDict})
38+
flatten(trace::NamedTuple)
3939
40-
Flatten all the values in the trace into a single vector.
40+
Flatten all the values in the trace into a single vector. Variable names information is discarded.
4141
4242
# Examples
4343
4444
```jldoctest; setup = :(using AbstractMCMC: flatten)
45-
julia> flatten((a=[1,2], b=[3,4,5]))
46-
5-element Vector{Int64}:
47-
1
48-
2
49-
3
50-
4
51-
5
45+
julia> flatten((a=ones(2), b=ones(2, 2)))
46+
6-element Vector{Float64}:
47+
1.0
48+
1.0
49+
1.0
50+
1.0
51+
1.0
52+
1.0
5253
5354
```
5455
"""
@@ -57,7 +58,7 @@ function flatten(trace::NamedTuple)
5758
end
5859

5960
"""
60-
unflatten(vec::AbstractVector, group_names_and_sizes::NamedTuple)
61+
unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
6162
6263
Reverse operation of flatten. Reshape the vector into the original arrays using size information.
6364
@@ -71,20 +72,19 @@ julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
7172
(x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0])
7273
```
7374
"""
74-
function unflatten(vec::AbstractVector, variable_sizes::NamedTuple)
75+
function unflatten(
76+
vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names}
77+
) where {variable_names}
7578
result = Dict{Symbol,Array}()
7679
start_idx = 1
77-
for name in keys(variable_sizes)
78-
size = variable_sizes[name]
80+
for name in variable_names
81+
size = variable_names_and_sizes[name]
7982
end_idx = start_idx + prod(size) - 1
8083
result[name] = reshape(vec[start_idx:end_idx], size...)
8184
start_idx = end_idx + 1
8285
end
8386

84-
# ensure the order of the keys is the same as the one in variable_sizes
85-
return NamedTuple{Tuple(keys(variable_sizes))}([
86-
result[name] for name in keys(variable_sizes)
87-
])
87+
return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names]))
8888
end
8989

9090
"""
@@ -95,15 +95,14 @@ Update the trace with the values from the MCMC states of the sub-problems.
9595
function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
9696
for parameter_variable in keys(gibbs_state.mcmc_states)
9797
sub_state = gibbs_state.mcmc_states[parameter_variable]
98-
trace = merge(
99-
trace,
100-
unflatten(
101-
vec(sub_state),
102-
NamedTuple{(parameter_variable,)}((
103-
gibbs_state.variable_sizes[parameter_variable],
104-
)),
105-
),
98+
sub_state_params = vec(sub_state)
99+
unflattened_sub_state_params = unflatten(
100+
sub_state_params,
101+
NamedTuple{(parameter_variable,)}((
102+
gibbs_state.variable_sizes[parameter_variable],
103+
)),
106104
)
105+
trace = merge(trace, unflattened_sub_state_params)
107106
end
108107
return trace
109108
end
File renamed without changes.

0 commit comments

Comments
 (0)