Skip to content

Commit fd472df

Browse files
committed
rework the code; still not type stable
1 parent af208bc commit fd472df

File tree

6 files changed

+69
-41
lines changed

6 files changed

+69
-41
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ desc = "A lightweight interface for common MCMC methods."
66
version = "5.3.0"
77

88
[deps]
9-
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
109
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1110
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
1211
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -28,18 +27,21 @@ ConsoleProgressMonitor = "0.1"
2827
FillArrays = "1"
2928
LogDensityProblems = "2"
3029
LoggingExtras = "0.4, 0.5, 1"
30+
MCMCChains = "6"
3131
ProgressLogging = "0.1"
3232
StatsBase = "0.32, 0.33, 0.34"
3333
TerminalLoggers = "0.1"
3434
Transducers = "0.4.30"
3535
julia = "1.6"
3636

3737
[extras]
38+
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
3839
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
3940
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4041
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
42+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
4143
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4244
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4345

4446
[targets]
45-
test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "Statistics", "Test"]
47+
test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "MCMCChains", "Statistics", "Test"]

src/AbstractMCMC.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,5 @@ include("sample.jl")
8787
include("stepper.jl")
8888
include("transducer.jl")
8989
include("logdensityproblems.jl")
90-
include("gibbs.jl")
9190

9291
end # module AbstractMCMC

test/gibbs_example/gibbs.jl

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using AbstractMCMC, AbstractPPL
2-
using BangBang.ConstructorBase: ConstructorBase
1+
using AbstractMCMC: AbstractMCMC
2+
using AbstractPPL: AbstractPPL
3+
using MCMCChains: Chains
4+
using Random
35

46
"""
57
Gibbs(sampler_map::NamedTuple)
@@ -99,27 +101,34 @@ function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
99101
return trace
100102
end
101103

104+
function error_if_not_fully_initialized(
105+
initial_params::NamedTuple{ParamNames}, sampler::Gibbs{<:NamedTuple{SamplerNames}}
106+
) where {ParamNames,SamplerNames}
107+
if Set(ParamNames) != Set(SamplerNames)
108+
throw(
109+
ArgumentError(
110+
"initial_params must contain all parameters in the model, expected $(SamplerNames), got $(ParamNames)",
111+
),
112+
)
113+
end
114+
end
115+
102116
function AbstractMCMC.step(
103117
rng::Random.AbstractRNG,
104118
logdensity_model::AbstractMCMC.LogDensityModel,
105-
sampler::Gibbs,
119+
sampler::Gibbs{Tsamplingmap},
106120
args...;
107121
initial_params::NamedTuple,
108122
kwargs...,
109-
)
110-
if Set(keys(initial_params)) != Set(keys(sampler.sampler_map))
111-
throw(
112-
ArgumentError(
113-
"initial_params must contain all parameters in the model, expected $(keys(sampler.sampler_map)), got $(keys(initial_params))",
114-
),
115-
)
116-
end
123+
) where {Tsamplingmap}
124+
error_if_not_fully_initialized(initial_params, sampler)
117125

118-
mcmc_states, variable_sizes = map(keys(sampler.sampler_map)) do parameter_variable
126+
model_parameter_names = fieldnames(Tsamplingmap)
127+
results = map(model_parameter_names) do parameter_variable
119128
sub_sampler = sampler.sampler_map[parameter_variable]
120129

121130
variables_to_be_conditioned_on = setdiff(
122-
keys(sampler.sampler_map), (parameter_variable,)
131+
model_parameter_names, (parameter_variable,)
123132
)
124133
conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
125134
Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
@@ -137,7 +146,7 @@ function AbstractMCMC.step(
137146
AbstractMCMC.step(
138147
rng,
139148
AbstractMCMC.LogDensityModel(
140-
AbstractMCMC.condition(
149+
AbstractPPL.condition(
141150
logdensity_model.logdensity, conditioning_variables_values
142151
),
143152
),
@@ -150,40 +159,46 @@ function AbstractMCMC.step(
150159
(sub_state, Tuple(size(initial_params[parameter_variable])))
151160
end
152161

162+
mcmc_states = first.(results)
163+
variable_sizes = last.(results)
164+
153165
gibbs_state = GibbsState(
154166
initial_params,
155-
NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states),
156-
NamedTuple{Tuple(keys(sampler.sampler_map))}(variable_sizes),
167+
NamedTuple{Tuple(model_parameter_names)}(mcmc_states),
168+
NamedTuple{Tuple(model_parameter_names)}(variable_sizes),
157169
)
170+
158171
trace = update_trace(NamedTuple(), gibbs_state)
159172
return GibbsTransition(trace), gibbs_state
160173
end
161174

175+
# subsequent steps
162176
function AbstractMCMC.step(
163177
rng::Random.AbstractRNG,
164178
logdensity_model::AbstractMCMC.LogDensityModel,
165-
sampler::Gibbs,
179+
sampler::Gibbs{Tsamplingmap},
166180
gibbs_state::GibbsState,
167181
args...;
168182
kwargs...,
169-
)
183+
) where {Tsamplingmap}
170184
(; trace, mcmc_states, variable_sizes) = gibbs_state
171185

172-
mcmc_states = map(keys(sampler.sampler_map)) do parameter_variable
186+
model_parameter_names = fieldnames(Tsamplingmap)
187+
mcmc_states = map(model_parameter_names) do parameter_variable
173188
sub_sampler = sampler.sampler_map[parameter_variable]
174189
sub_state = mcmc_states[parameter_variable]
175190
variables_to_be_conditioned_on = setdiff(
176-
sampler.parameter_names, (parameter_variable,)
191+
model_parameter_names, (parameter_variable,)
177192
)
178193
conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
179194
Tuple([trace[g] for g in variables_to_be_conditioned_on])
180195
)
181-
cond_logdensity = AbstractMCMC.condition(
196+
cond_logdensity = AbstractPPL.condition(
182197
logdensity_model.logdensity, conditioning_variables_values
183198
)
184199

185-
logp = LogDensityProblems.logdensity_and_state(cond_logdensity, sub_state)
186-
sub_state = constructorof(typeof(sub_state))(; logp=logp)
200+
logp = LogDensityProblems.logdensity(cond_logdensity, sub_state)
201+
sub_state = (sub_state)(logp)
187202
sub_state = last(
188203
AbstractMCMC.step(
189204
rng,
@@ -197,7 +212,7 @@ function AbstractMCMC.step(
197212
trace = update_trace(trace, gibbs_state)
198213
sub_state
199214
end
200-
mcmc_states = NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states)
215+
mcmc_states = NamedTuple{Tuple(model_parameter_names)}(mcmc_states)
201216

202217
return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes)
203218
end

test/gibbs_example/hier_normal.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using AbstractPPL: AbstractPPL
2+
13
abstract type AbstractHierNormal end
24

35
struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal
@@ -7,6 +9,8 @@ end
79
struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <:
810
AbstractHierNormal
911
data::Tdata
12+
13+
" The variable to be conditioned on and its value"
1014
conditioned_values::Tconditioned_vars
1115
end
1216

@@ -36,14 +40,15 @@ function log_joint(; mu, tau2, x)
3640
return logp
3741
end
3842

39-
function AbstractMCMC.condition(hn::HierNormal, conditioned_values::NamedTuple)
43+
function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
4044
return ConditionedHierNormal(hn.data, conditioned_values)
4145
end
4246

4347
function LogDensityProblems.logdensity(
44-
hier_normal_model::ConditionedHierNormal{names}, params::AbstractVector
45-
) where {names}
46-
variable_to_condition = only(names)
48+
hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars},
49+
params::AbstractVector,
50+
) where {Tdata,Tconditioned_vars}
51+
variable_to_condition = only(fieldnames(Tconditioned_vars))
4752
data = hier_normal_model.data
4853
conditioned_values = hier_normal_model.conditioned_values
4954

test/gibbs_example/mh.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
using AbstractMCMC: AbstractMCMC, LogDensityProblems
12
using Distributions
2-
3+
using Random
34
abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end
45

56
struct MHState{T}
67
params::Vector{T}
78
logp::Float64
89
end
910

11+
# Interface 3: (state::MHState)(logp::Float64)
12+
# This function allows the state to be updated with a new log probability.
13+
# ! this makes state into a Julia functor
14+
(state::MHState)(logp::Float64) = MHState(state.params, logp)
15+
1016
struct MHTransition{T}
1117
params::Vector{T}
1218
end
@@ -15,7 +21,7 @@ end
1521
# This function takes the logdensity function and the state (state is defined by the sampler package)
1622
# and returns the logdensity. It allows for optional recomputation of the log probability.
1723
# If recomputation is not needed, it returns the stored log probability from the state.
18-
function AbstractMCMC.logdensity_and_state(
24+
function LogDensityProblems.logdensity(
1925
logdensity_function, state::MHState; recompute_logp=true
2026
)
2127
return if recompute_logp
@@ -28,9 +34,7 @@ end
2834
# Interface 2: Base.vec
2935
# This function takes a state and returns a vector of the parameter values stored in the state.
3036
# It is part of the interface for interacting with the state object.
31-
function Base.vec(state::MHState)
32-
return state.params
33-
end
37+
Base.vec(state::MHState) = state.params
3438

3539
"""
3640
RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
@@ -62,14 +66,17 @@ function AbstractMCMC.step(
6266
)
6367
logdensity_function = logdensity_model.logdensity
6468
transition = MHTransition(initial_params)
65-
state = MHState(initial_params, only(logdensity_function(initial_params)))
69+
state = MHState(
70+
initial_params,
71+
only(LogDensityProblems.logdensity(logdensity_function, initial_params)),
72+
)
6673

6774
return transition, state
6875
end
6976

70-
@inline proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) =
77+
@inline get_proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) =
7178
MvNormal(current_params, sampler.σ)
72-
@inline proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} =
79+
@inline get_proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} =
7380
sampler.proposal_dist
7481

7582
# the subsequent steps of the sampler
@@ -83,7 +90,7 @@ function AbstractMCMC.step(
8390
)
8491
logdensity_function = logdensity_model.logdensity
8592
current_params = state.params
86-
proposal_dist = proposal_dist(sampler, current_params)
93+
proposal_dist = get_proposal_dist(sampler, current_params)
8794
proposed_params = rand(rng, proposal_dist)
8895
logp_proposal = only(
8996
LogDensityProblems.logdensity(logdensity_function, proposed_params)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ include("utils.jl")
2424
include("stepper.jl")
2525
include("transducer.jl")
2626
include("logdensityproblems.jl")
27-
include("gibbs_example/gibbs.jl")
27+
include("gibbs_example/gibbs_test.jl")
2828
end

0 commit comments

Comments
 (0)