1
1
using Distributions
2
2
3
+ abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end
4
+
3
5
struct MHState{T}
4
6
params:: Vector{T}
5
7
logp:: Float64
@@ -9,65 +11,89 @@ struct MHTransition{T}
9
11
params:: Vector{T}
10
12
end
11
13
14
+ # Interface 1: LogDensityProblems.logdensity
15
+ # This function takes the logdensity function and the state (state is defined by the sampler package)
16
+ # and returns the logdensity. It allows for optional recomputation of the log probability.
17
+ # If recomputation is not needed, it returns the stored log probability from the state.
12
18
function AbstractMCMC. logdensity_and_state (
13
- logdensity_function, state:: MHState ; recompute_logp:: Bool = true
19
+ logdensity_function, state:: MHState ; recompute_logp= true
14
20
)
15
- if recompute_logp
16
- logp = AbstractMCMC. LogDensityProblems. logdensity (logdensity_function, state. params)
17
- return logp, MHState (state. params, logp)
21
+ return if recompute_logp
22
+ AbstractMCMC. LogDensityProblems. logdensity (logdensity_function, state. params)
18
23
else
19
- return state. logp, state
24
+ state. logp
20
25
end
21
26
end
22
27
28
+ # Interface 2: Base.vec
29
+ # This function takes a state and returns a vector of the parameter values stored in the state.
30
+ # It is part of the interface for interacting with the state object.
23
31
function Base. vec (state:: MHState )
24
32
return state. params
25
33
end
26
34
27
- struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
35
+ """
36
+ RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
37
+
38
+ A random walk Metropolis-Hastings sampler with a normal proposal distribution. The field σ
39
+ is the standard deviation of the proposal distribution.
40
+ """
41
+ struct RandomWalkMH{T} <: AbstractMHSampler
28
42
σ:: T
29
43
end
30
44
31
- struct IndependentMH{T} <: AbstractMCMC.AbstractSampler
45
+ """
46
+ IndependentMH{T} <: AbstractMCMC.AbstractSampler
47
+
48
+ A Metropolis-Hastings sampler with an independent proposal distribution.
49
+ """
50
+ struct IndependentMH{T} <: AbstractMHSampler
32
51
proposal_dist:: T
33
52
end
34
53
54
+ # the first step of the sampler
35
55
function AbstractMCMC. step (
36
56
rng:: AbstractRNG ,
37
57
logdensity_model:: AbstractMCMC.LogDensityModel ,
38
- sampler:: Union{RandomWalkMH,IndependentMH} ,
58
+ sampler:: AbstractMHSampler ,
39
59
args... ;
40
60
initial_params,
41
61
kwargs... ,
42
62
)
43
- return MHTransition (initial_params),
44
- MHState (
45
- initial_params,
46
- only (LogDensityProblems . logdensity (logdensity_model . logdensity, initial_params)),
47
- )
63
+ logdensity_function = logdensity_model . logdensity
64
+ transition = MHTransition (initial_params)
65
+ state = MHState ( initial_params, only ( logdensity_function (initial_params)))
66
+
67
+ return transition, state
48
68
end
49
69
70
+ @inline proposal_dist (sampler:: RandomWalkMH , current_params:: Vector{Float64} ) =
71
+ MvNormal (current_params, sampler. σ)
72
+ @inline proposal_dist (sampler:: IndependentMH , current_params:: Vector{T} ) where {T} =
73
+ sampler. proposal_dist
74
+
75
+ # the subsequent steps of the sampler
50
76
function AbstractMCMC. step (
51
77
rng:: AbstractRNG ,
52
78
logdensity_model:: AbstractMCMC.LogDensityModel ,
53
- sampler:: Union{RandomWalkMH,IndependentMH} ,
79
+ sampler:: AbstractMHSampler ,
54
80
state:: MHState ,
55
81
args... ;
56
82
kwargs... ,
57
83
)
58
- params = state . params
59
- proposal_dist =
60
- sampler isa RandomWalkMH ? MvNormal (state . params, sampler . σ) : sampler . proposal_dist
61
- proposal = rand (rng, proposal_dist)
84
+ logdensity_function = logdensity_model . logdensity
85
+ current_params = state . params
86
+ proposal_dist = proposal_dist (sampler, current_params)
87
+ proposed_params = rand (rng, proposal_dist)
62
88
logp_proposal = only (
63
- LogDensityProblems. logdensity (logdensity_model . logdensity, proposal )
89
+ LogDensityProblems. logdensity (logdensity_function, proposed_params )
64
90
)
65
91
66
92
if log (rand (rng)) <
67
- compute_log_acceptance_ratio (sampler, state, proposal , logp_proposal)
68
- return MHTransition (proposal ), MHState (proposal , logp_proposal)
93
+ compute_log_acceptance_ratio (sampler, state, proposed_params , logp_proposal)
94
+ return MHTransition (proposed_params ), MHState (proposed_params , logp_proposal)
69
95
else
70
- return MHTransition (params ), MHState (params , state. logp)
96
+ return MHTransition (current_params ), MHState (current_params , state. logp)
71
97
end
72
98
end
73
99
0 commit comments