1
1
abstract type Proposal{P} end
2
2
3
- struct StaticProposal{P} <: Proposal{P}
3
+ struct StaticProposal{issymmetric, P} <: Proposal{P}
4
4
proposal:: P
5
5
end
6
+ const SymmetricStaticProposal{P} = StaticProposal{true ,P}
6
7
7
- struct RandomWalkProposal{P} <: Proposal{P}
8
+ StaticProposal (proposal) = StaticProposal {false} (proposal)
9
+ function StaticProposal {issymmetric} (proposal) where {issymmetric}
10
+ return StaticProposal {issymmetric,typeof(proposal)} (proposal)
11
+ end
12
+
13
+ struct RandomWalkProposal{issymmetric,P} <: Proposal{P}
8
14
proposal:: P
9
15
end
16
+ const SymmetricRandomWalkProposal{P} = RandomWalkProposal{true ,P}
17
+
18
+ RandomWalkProposal (proposal) = RandomWalkProposal {false} (proposal)
19
+ function RandomWalkProposal {issymmetric} (proposal) where {issymmetric}
20
+ return RandomWalkProposal {issymmetric,typeof(proposal)} (proposal)
21
+ end
10
22
11
23
# Random draws
12
24
Base. rand (p:: Proposal , args... ) = rand (Random. GLOBAL_RNG, p, args... )
26
38
# Random Walk #
27
39
# ##############
28
40
29
- function propose (rng:: Random.AbstractRNG , p:: RandomWalkProposal , m:: DensityModel )
30
- return propose (rng, StaticProposal (p. proposal), m)
41
+ function propose (
42
+ rng:: Random.AbstractRNG ,
43
+ proposal:: RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}} ,
44
+ :: DensityModel
45
+ ) where {issymmetric}
46
+ return rand (rng, proposal)
31
47
end
32
48
33
49
function propose (
34
50
rng:: Random.AbstractRNG ,
35
- proposal:: RandomWalkProposal{<:Union{Distribution,AbstractArray}} ,
36
- model:: DensityModel ,
51
+ proposal:: RandomWalkProposal{issymmetric, <:Union{Distribution,AbstractArray}} ,
52
+ model:: DensityModel ,
37
53
t
38
- )
54
+ ) where {issymmetric}
39
55
return t + rand (rng, proposal)
40
56
end
41
57
42
58
function q (
43
- proposal:: RandomWalkProposal{<:Union{Distribution,AbstractArray}} ,
59
+ proposal:: RandomWalkProposal{issymmetric, <:Union{Distribution,AbstractArray}} ,
44
60
t,
45
61
t_cond
46
- )
62
+ ) where {issymmetric}
47
63
return logpdf (proposal, t - t_cond)
48
64
end
49
65
53
69
54
70
function propose (
55
71
rng:: Random.AbstractRNG ,
56
- proposal:: StaticProposal{<:Union{Distribution,AbstractArray}} ,
72
+ proposal:: StaticProposal{issymmetric, <:Union{Distribution,AbstractArray}} ,
57
73
model:: DensityModel ,
58
74
t= nothing
59
- )
75
+ ) where {issymmetric}
60
76
return rand (rng, proposal)
61
77
end
62
78
63
79
function q (
64
- proposal:: StaticProposal{<:Union{Distribution,AbstractArray}} ,
80
+ proposal:: StaticProposal{issymmetric, <:Union{Distribution,AbstractArray}} ,
65
81
t,
66
82
t_cond
67
- )
83
+ ) where {issymmetric}
68
84
return logpdf (proposal, t)
69
85
end
70
86
73
89
# ###########
74
90
75
91
# function definition with abstract types requires Julia 1.3 or later
76
- for T in (StaticProposal, RandomWalkProposal)
92
+ for T in (: StaticProposal , : RandomWalkProposal )
77
93
@eval begin
78
- (p:: $T{<:Function} )() = $ T (p. proposal ())
79
- (p:: $T{<:Function} )(t) = $ T (p. proposal (t))
94
+ function (p:: $T{issymmetric,<:Function} )() where {issymmetric}
95
+ return $ T {issymmetric} (p. proposal ())
96
+ end
97
+ function (p:: $T{issymmetric,<:Function} )(t) where {issymmetric}
98
+ return $ T {issymmetric} (p. proposal (t))
99
+ end
80
100
end
81
101
end
82
102
@@ -103,4 +123,69 @@ function q(
103
123
t_cond
104
124
)
105
125
return q (proposal (t_cond), t, t_cond)
106
- end
126
+ end
127
+
128
+ """
129
+ logratio_proposal_density(proposal, state, candidate)
130
+
131
+ Compute the log-ratio of the proposal densities in the Metropolis-Hastings algorithm.
132
+
133
+ The log-ratio of the proposal densities is defined as
134
+ ```math
135
+ \\ log \\ frac{g(x | x')}{g(x' | x)},
136
+ ```
137
+ where ``x`` is the current state, ``x'`` is the proposed candidate for the next state,
138
+ and ``g(y' | y)`` is the conditional probability of proposing state ``y'`` given state
139
+ ``y`` (proposal density).
140
+ """
141
+ function logratio_proposal_density (proposal:: Proposal , state, candidate)
142
+ return q (proposal, state, candidate) - q (proposal, candidate, state)
143
+ end
144
+
145
+ # ratio is always 0 for symmetric proposals
146
+ logratio_proposal_density (:: RandomWalkProposal{true} , state, candidate) = 0
147
+ logratio_proposal_density (:: StaticProposal{true} , state, candidate) = 0
148
+
149
+ # type stable implementation for `NamedTuple`s
150
+ function logratio_proposal_density (
151
+ proposals:: NamedTuple{names} , states:: NamedTuple , candidates:: NamedTuple
152
+ ) where {names}
153
+ if @generated
154
+ args = map (names) do name
155
+ :(logratio_proposal_density (
156
+ proposals[$ (QuoteNode (name))],
157
+ states[$ (QuoteNode (name))],
158
+ candidates[$ (QuoteNode (name))],
159
+ ))
160
+ end
161
+ return :(+ ($ (args... )))
162
+ else
163
+ return sum (names) do name
164
+ return logratio_proposal_density (
165
+ proposals[name], states[name], candidates[name]
166
+ )
167
+ end
168
+ end
169
+ end
170
+
171
+ # use recursion for `Tuple`s to ensure type stability
172
+ logratio_proposal_density (proposals:: Tuple{} , states:: Tuple , candidates:: Tuple ) = 0
173
+ function logratio_proposal_density (
174
+ proposals:: Tuple{<:Proposal} , states:: Tuple , candidates:: Tuple
175
+ )
176
+ return logratio_proposal_density (first (proposals), first (states), first (candidates))
177
+ end
178
+ function logratio_proposal_density (proposals:: Tuple , states:: Tuple , candidates:: Tuple )
179
+ valfirst = logratio_proposal_density (first (proposals), first (states), first (candidates))
180
+ valtail = logratio_proposal_density (
181
+ Base. tail (proposals), Base. tail (states), Base. tail (candidates)
182
+ )
183
+ return valfirst + valtail
184
+ end
185
+
186
+ # fallback for general iterators (arrays etc.) - possibly not type stable!
187
+ function logratio_proposal_density (proposals, states, candidates)
188
+ return sum (zip (proposals, states, candidates)) do (proposal, state, candidate)
189
+ return logratio_proposal_density (proposal, state, candidate)
190
+ end
191
+ end
0 commit comments