Skip to content

Commit c47ade4

Browse files
committed
update code further
1 parent b262ea9 commit c47ade4

File tree

5 files changed

+23
-46
lines changed

5 files changed

+23
-46
lines changed

src/AbstractMCMC.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,36 +82,6 @@ struct MCMCSerial <: AbstractMCMCEnsemble end
8282

8383
function condition end
8484

85-
function recompute_logprob!! end
86-
87-
"""
88-
get_logprob(state)
89-
90-
Returns the log-probability of the last sampling step, stored in `state`.
91-
"""
92-
function get_logprob(state) end
93-
94-
"""
95-
set_logprob!(state, logprob)
96-
97-
Set the log-probability of the last sampling step, stored in `state`.
98-
"""
99-
function set_logprob!!(state, logprob) end
100-
101-
"""
102-
get_params(state)
103-
104-
Returns the values of the parameters in the state.
105-
"""
106-
function get_params(state) end
107-
108-
"""
109-
setparams!(state, params)
110-
111-
Set the values of the parameters in the state.
112-
"""
113-
function set_params!!(state, params) end
114-
11585
include("samplingstats.jl")
11686
include("logging.jl")
11787
include("interface.jl")

src/gibbs.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
7777
function unflatten(vec::AbstractVector, variable_sizes::NamedTuple)
7878
result = Dict{Symbol,Array}()
7979
start_idx = 1
80-
for (name, size) in pairs(variable_sizes)
80+
for name in keys(variable_sizes)
81+
size = variable_sizes[name]
8182
end_idx = start_idx + prod(size) - 1
8283
result[name] = reshape(vec[start_idx:end_idx], size...)
8384
start_idx = end_idx + 1
@@ -100,7 +101,7 @@ function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
100101
trace = merge(
101102
trace,
102103
unflatten(
103-
AbstractMCMC.get_params(sub_state),
104+
vec(sub_state),
104105
NamedTuple{(parameter_variable,)}((
105106
gibbs_state.variable_sizes[parameter_variable],
106107
)),
@@ -197,9 +198,14 @@ function AbstractMCMC.step(
197198
)
198199

199200
# recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems
200-
sub_state = AbstractMCMC.recompute_logprob!!(
201-
cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state
202-
)
201+
updated_log_prob = LogDensityProblems.logdensity(cond_logdensity, sub_state)
202+
203+
if !hasproperty(sub_state, :logp)
204+
error(
205+
"$(typeof(sub_state)) does not have a `:logp` field, which is required by Gibbs sampling",
206+
)
207+
end
208+
sub_state = BangBang.setproperty!!(sub_state, :logp, updated_log_prob)
203209

204210
sub_state = last(
205211
AbstractMCMC.step(

test/gibbs_example/gibbs.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ include("hier_normal.jl")
55
@testset "hierarchical normal with gibbs" begin
66
# generate data
77
N = 1000 # Number of data points
8-
mu_true = 0.5 # True mean
8+
mu_true = 5 # True mean
99
tau2_true = 2.0 # True variance
1010
x_data = rand(Distributions.Normal(mu_true, sqrt(tau2_true)), N)
1111

@@ -15,7 +15,8 @@ include("hier_normal.jl")
1515
samples = sample(
1616
hn,
1717
AbstractMCMC.Gibbs((
18-
mu=RandomWalkMH(1), tau2=IndependentMH(product_distribution([InverseGamma(1, 1)]))
18+
mu=RandomWalkMH(0.3),
19+
tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])),
1920
)),
2021
200000;
2122
initial_params=(mu=[0.0], tau2=[1.0]),

test/gibbs_example/gmm.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,3 @@ end
7777
function unflatten(vec::AbstractVector, group::Tuple)
7878
return NamedTuple((only(group) => vec,))
7979
end
80-
81-
function recompute_logprob!!(gmm::ConditionedGMM, vals, state)
82-
return set_logp!!(state, LogDensityProblems.logdensity(gmm, vals))
83-
end

test/gibbs_example/mh.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ struct MHTransition{T}
99
params::Vector{T}
1010
end
1111

12-
AbstractMCMC.get_params(state::MHState) = state.params
13-
AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
14-
AbstractMCMC.get_logprob(state::MHState) = state.logp
15-
AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
12+
function AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state::MHState)
13+
# recompute the logdensity, instead of using the one stored in the state
14+
return AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params)
15+
end
16+
17+
function Base.vec(state::MHState)
18+
return state.params
19+
end
1620

1721
struct RandomWalkMH <: AbstractMCMC.AbstractSampler
1822
σ::Float64
@@ -64,7 +68,7 @@ end
6468
function compute_log_acceptance_ratio(
6569
::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64
6670
)
67-
return min(0, logp_proposal - AbstractMCMC.get_logprob(state))
71+
return min(0, logp_proposal - state.logp)
6872
end
6973

7074
function compute_log_acceptance_ratio(

0 commit comments

Comments
 (0)