Skip to content

Commit 280eaf1

Browse files
committed
update gibbs to add to the src folder
1 parent ac0ce7a commit 280eaf1

File tree

7 files changed

+297
-232
lines changed

7 files changed

+297
-232
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1414
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1515
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1616
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
17-
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1817
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1918
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2019
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -37,9 +36,8 @@ julia = "1.6"
3736
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
3837
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3938
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
40-
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4139
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4240
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4341

4442
[targets]
45-
test = ["FillArrays", "Distributions", "IJulia", "OrderedCollections", "Statistics", "Test"]
43+
test = ["FillArrays", "Distributions", "IJulia", "Statistics", "Test"]

src/AbstractMCMC.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module AbstractMCMC
22

33
using BangBang: BangBang
4-
using Compat
54
using ConsoleProgressMonitor: ConsoleProgressMonitor
65
using LogDensityProblems: LogDensityProblems
76
using LoggingExtras: LoggingExtras
@@ -81,6 +80,10 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr
8180
"""
8281
struct MCMCSerial <: AbstractMCMCEnsemble end
8382

83+
function condition end
84+
85+
function recompute_logprob!! end
86+
8487
"""
8588
get_logprob(state)
8689
@@ -116,5 +119,6 @@ include("sample.jl")
116119
include("stepper.jl")
117120
include("transducer.jl")
118121
include("logdensityproblems.jl")
122+
include("gibbs.jl")
119123

120124
end # module AbstractMCMC

src/gibbs.jl

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
"""
2+
Gibbs(sampler_map::NamedTuple)
3+
4+
An interface for block sampling in Markov Chain Monte Carlo (MCMC).
5+
6+
Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems.
7+
It allows different sampling methods to be applied to different parameters.
8+
"""
9+
struct Gibbs <: AbstractMCMC.AbstractSampler
10+
sampler_map::NamedTuple
11+
parameter_names::Tuple{Vararg{Symbol}}
12+
13+
function Gibbs(sampler_map::NamedTuple)
14+
parameter_names = Tuple(keys(sampler_map))
15+
return new(sampler_map, parameter_names)
16+
end
17+
end
18+
19+
struct GibbsState
20+
"""
21+
`trace` contains the values of the values of _all_ parameters up to the last iteration.
22+
"""
23+
trace::NamedTuple
24+
25+
"""
26+
`mcmc_states` maps parameters to their sampler-specific MCMC states.
27+
"""
28+
mcmc_states::NamedTuple
29+
30+
"""
31+
`variable_sizes` maps parameters to their sizes.
32+
"""
33+
variable_sizes::NamedTuple
34+
end
35+
36+
struct GibbsTransition
37+
"""
38+
Realizations of the parameters, this is considered a "sample" in the MCMC chain.
39+
"""
40+
values::NamedTuple
41+
end
42+
43+
"""
44+
flatten(trace::Union{NamedTuple,OrderedCollections.OrderedDict})
45+
46+
Flatten all the values in the trace into a single vector.
47+
48+
# Examples
49+
50+
```jldoctest
51+
julia> flatten((a=[1,2], b=[3,4,5]))
52+
[1, 2, 3, 4, 5]
53+
54+
julia> flatten(OrderedCollections.OrderedDict(:x=>[1.0,2.0], :y=>[3.0,4.0,5.0]))
55+
[1.0, 2.0, 3.0, 4.0, 5.0]
56+
```
57+
"""
58+
function flatten(trace::NamedTuple)
59+
return reduce(vcat, vec.(values(trace)))
60+
end
61+
62+
"""
63+
unflatten(vec::AbstractVector, group_names_and_sizes::NamedTuple)
64+
65+
Reverse operation of flatten. Reshape the vector into the original arrays using size information.
66+
67+
# Examples
68+
69+
```jldoctest
70+
julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,)))
71+
(a=[1,2], b=[3,4,5])
72+
73+
julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
74+
(x=[1.0 3.0; 2.0 4.0], y=[5.0,6.0])
75+
```
76+
"""
77+
function unflatten(vec::AbstractVector, variable_sizes::NamedTuple)
78+
result = Dict{Symbol,Array}()
79+
start_idx = 1
80+
for (name, size) in pairs(variable_sizes)
81+
end_idx = start_idx + prod(size) - 1
82+
result[name] = reshape(vec[start_idx:end_idx], size...)
83+
start_idx = end_idx + 1
84+
end
85+
86+
# ensure the order of the keys is the same as the one in variable_sizes
87+
return NamedTuple{Tuple(keys(variable_sizes))}([
88+
result[name] for name in keys(variable_sizes)
89+
])
90+
end
91+
92+
"""
93+
update_trace(trace::NamedTuple, gibbs_state::GibbsState)
94+
95+
Update the trace with the values from the MCMC states of the sub-problems.
96+
"""
97+
function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
98+
for parameter_variable in keys(gibbs_state.mcmc_states)
99+
sub_state = gibbs_state.mcmc_states[parameter_variable]
100+
trace = merge(
101+
trace,
102+
unflatten(
103+
AbstractMCMC.get_params(sub_state),
104+
NamedTuple{(parameter_variable,)}((
105+
gibbs_state.variable_sizes[parameter_variable],
106+
)),
107+
),
108+
)
109+
end
110+
return trace
111+
end
112+
113+
function AbstractMCMC.step(
114+
rng::Random.AbstractRNG,
115+
logdensity_model::AbstractMCMC.LogDensityModel,
116+
sampler::Gibbs,
117+
args...;
118+
initial_params::NamedTuple,
119+
kwargs...,
120+
)
121+
if Set(keys(initial_params)) != Set(sampler.parameter_names)
122+
throw(
123+
ArgumentError(
124+
"initial_params must contain all parameters in the model, expected $(sampler.parameter_names), got $(keys(initial_params))",
125+
),
126+
)
127+
end
128+
129+
mcmc_states = Dict{Symbol,Any}()
130+
variable_sizes = Dict{Symbol,Tuple}()
131+
for parameter_variable in sampler.parameter_names
132+
sub_sampler = sampler.sampler_map[parameter_variable]
133+
134+
variables_to_be_conditioned_on = setdiff(
135+
sampler.parameter_names, (parameter_variable,)
136+
)
137+
conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
138+
Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
139+
)
140+
sub_problem_parameters_values = NamedTuple{(parameter_variable,)}((
141+
initial_params[parameter_variable],
142+
))
143+
144+
# LogDensityProblems' `logdensity` function expects a single vector of real numbers
145+
# `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values
146+
# and unflatten after the sampling step
147+
variable_sizes[parameter_variable] = Tuple(size(initial_params[parameter_variable]))
148+
flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values)
149+
150+
sub_state = last(
151+
AbstractMCMC.step(
152+
rng,
153+
AbstractMCMC.LogDensityModel(
154+
AbstractMCMC.condition(
155+
logdensity_model.logdensity, conditioning_variables_values
156+
),
157+
),
158+
sub_sampler,
159+
args...;
160+
initial_params=flattened_sub_problem_parameters_values,
161+
kwargs...,
162+
),
163+
)
164+
mcmc_states[parameter_variable] = sub_state
165+
end
166+
167+
gibbs_state = GibbsState(
168+
initial_params, NamedTuple(mcmc_states), NamedTuple(variable_sizes)
169+
)
170+
trace = update_trace(NamedTuple(), gibbs_state)
171+
return GibbsTransition(trace), gibbs_state
172+
end
173+
174+
function AbstractMCMC.step(
175+
rng::Random.AbstractRNG,
176+
logdensity_model::AbstractMCMC.LogDensityModel,
177+
sampler::Gibbs,
178+
gibbs_state::GibbsState,
179+
args...;
180+
kwargs...,
181+
)
182+
(; trace, mcmc_states, variable_sizes) = gibbs_state
183+
mcmc_states_dict = Dict(
184+
keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)]
185+
)
186+
for parameter_variable in sampler.parameter_names
187+
sub_sampler = sampler.sampler_map[parameter_variable]
188+
sub_state = mcmc_states[parameter_variable]
189+
variables_to_be_conditioned_on = setdiff(
190+
sampler.parameter_names, (parameter_variable,)
191+
)
192+
conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
193+
Tuple([trace[g] for g in variables_to_be_conditioned_on])
194+
)
195+
cond_logdensity = AbstractMCMC.condition(
196+
logdensity_model.logdensity, conditioning_variables_values
197+
)
198+
199+
# recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems
200+
sub_state = AbstractMCMC.recompute_logprob!!(
201+
cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state
202+
)
203+
204+
sub_state = last(
205+
AbstractMCMC.step(
206+
rng,
207+
AbstractMCMC.LogDensityModel(cond_logdensity),
208+
sub_sampler,
209+
sub_state,
210+
args...;
211+
kwargs...,
212+
),
213+
)
214+
mcmc_states_dict[parameter_variable] = sub_state
215+
trace = update_trace(trace, gibbs_state)
216+
end
217+
218+
mcmc_states = NamedTuple{Tuple(keys(mcmc_states_dict))}(
219+
Tuple([mcmc_states_dict[k] for k in keys(mcmc_states_dict)])
220+
)
221+
return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes)
222+
end

0 commit comments

Comments
 (0)