Skip to content

Commit af208bc

Browse files
committed
updates
1 parent 6132f0c commit af208bc

File tree

2 files changed

+65
-33
lines changed

2 files changed

+65
-33
lines changed

test/gibbs_example/hier_normal.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
abstract type AbstractHierNormal end
22

3-
struct HierNormal <: AbstractHierNormal
4-
data::NamedTuple
3+
struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal
4+
data::Tdata
55
end
66

7-
struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal
8-
data::NamedTuple
9-
conditioned_values::NamedTuple{conditioned_vars}
7+
struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <:
8+
AbstractHierNormal
9+
data::Tdata
10+
conditioned_values::Tconditioned_vars
1011
end
1112

13+
# `mu` and `tau2` are length-1 vectors to make
1214
function log_joint(; mu, tau2, x)
1315
# mu is the mean
1416
# tau2 is the variance
@@ -39,14 +41,18 @@ function AbstractMCMC.condition(hn::HierNormal, conditioned_values::NamedTuple)
3941
end
4042

4143
function LogDensityProblems.logdensity(
42-
hn::ConditionedHierNormal{names}, params::AbstractVector
44+
hier_normal_model::ConditionedHierNormal{names}, params::AbstractVector
4345
) where {names}
44-
if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2
45-
return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x)
46-
elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu
47-
return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x)
46+
variable_to_condition = only(names)
47+
data = hier_normal_model.data
48+
conditioned_values = hier_normal_model.conditioned_values
49+
50+
if variable_to_condition == :mu
51+
return log_joint(; mu=conditioned_values.mu, tau2=params, x=data.x)
52+
elseif variable_to_condition == :tau2
53+
return log_joint(; mu=params, tau2=conditioned_values.tau2, x=data.x)
4854
else
49-
error("Unsupported conditioning configuration.")
55+
error("Unsupported conditioning variable: $variable_to_condition")
5056
end
5157
end
5258

test/gibbs_example/mh.jl

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Distributions
22

3+
abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end
4+
35
struct MHState{T}
46
params::Vector{T}
57
logp::Float64
@@ -9,65 +11,89 @@ struct MHTransition{T}
911
params::Vector{T}
1012
end
1113

14+
# Interface 1: LogDensityProblems.logdensity
15+
# This function takes the logdensity function and the state (state is defined by the sampler package)
16+
# and returns the logdensity. It allows for optional recomputation of the log probability.
17+
# If recomputation is not needed, it returns the stored log probability from the state.
1218
function AbstractMCMC.logdensity_and_state(
13-
logdensity_function, state::MHState; recompute_logp::Bool=true
19+
logdensity_function, state::MHState; recompute_logp=true
1420
)
15-
if recompute_logp
16-
logp = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params)
17-
return logp, MHState(state.params, logp)
21+
return if recompute_logp
22+
AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params)
1823
else
19-
return state.logp, state
24+
state.logp
2025
end
2126
end
2227

28+
# Interface 2: Base.vec
29+
# This function takes a state and returns a vector of the parameter values stored in the state.
30+
# It is part of the interface for interacting with the state object.
2331
function Base.vec(state::MHState)
2432
return state.params
2533
end
2634

27-
struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
35+
"""
36+
RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
37+
38+
A random walk Metropolis-Hastings sampler with a normal proposal distribution. The field σ
39+
is the standard deviation of the proposal distribution.
40+
"""
41+
struct RandomWalkMH{T} <: AbstractMHSampler
2842
σ::T
2943
end
3044

31-
struct IndependentMH{T} <: AbstractMCMC.AbstractSampler
45+
"""
46+
IndependentMH{T} <: AbstractMCMC.AbstractSampler
47+
48+
A Metropolis-Hastings sampler with an independent proposal distribution.
49+
"""
50+
struct IndependentMH{T} <: AbstractMHSampler
3251
proposal_dist::T
3352
end
3453

54+
# the first step of the sampler
3555
function AbstractMCMC.step(
3656
rng::AbstractRNG,
3757
logdensity_model::AbstractMCMC.LogDensityModel,
38-
sampler::Union{RandomWalkMH,IndependentMH},
58+
sampler::AbstractMHSampler,
3959
args...;
4060
initial_params,
4161
kwargs...,
4262
)
43-
return MHTransition(initial_params),
44-
MHState(
45-
initial_params,
46-
only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
47-
)
63+
logdensity_function = logdensity_model.logdensity
64+
transition = MHTransition(initial_params)
65+
state = MHState(initial_params, only(logdensity_function(initial_params)))
66+
67+
return transition, state
4868
end
4969

70+
@inline proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) =
71+
MvNormal(current_params, sampler.σ)
72+
@inline proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} =
73+
sampler.proposal_dist
74+
75+
# the subsequent steps of the sampler
5076
function AbstractMCMC.step(
5177
rng::AbstractRNG,
5278
logdensity_model::AbstractMCMC.LogDensityModel,
53-
sampler::Union{RandomWalkMH,IndependentMH},
79+
sampler::AbstractMHSampler,
5480
state::MHState,
5581
args...;
5682
kwargs...,
5783
)
58-
params = state.params
59-
proposal_dist =
60-
sampler isa RandomWalkMH ? MvNormal(state.params, sampler.σ) : sampler.proposal_dist
61-
proposal = rand(rng, proposal_dist)
84+
logdensity_function = logdensity_model.logdensity
85+
current_params = state.params
86+
proposal_dist = proposal_dist(sampler, current_params)
87+
proposed_params = rand(rng, proposal_dist)
6288
logp_proposal = only(
63-
LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
89+
LogDensityProblems.logdensity(logdensity_function, proposed_params)
6490
)
6591

6692
if log(rand(rng)) <
67-
compute_log_acceptance_ratio(sampler, state, proposal, logp_proposal)
68-
return MHTransition(proposal), MHState(proposal, logp_proposal)
93+
compute_log_acceptance_ratio(sampler, state, proposed_params, logp_proposal)
94+
return MHTransition(proposed_params), MHState(proposed_params, logp_proposal)
6995
else
70-
return MHTransition(params), MHState(params, state.logp)
96+
return MHTransition(current_params), MHState(current_params, state.logp)
7197
end
7298
end
7399

0 commit comments

Comments
 (0)