Skip to content

Commit 85e6aa6

Browse files
devmotiongithub-actions[bot]penelopeysm
authored
Use Random.randexp (#393)
* Use `Random.randexp` * A few more changes * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update `mh_accept_ratio` * Increase number of samples * Fix CI setup * Update mac CI settings --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
1 parent 652bb38 commit 85e6aa6

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

src/trajectory.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ $(TYPEDEF)
133133
Slice sampler for the starting single leaf tree.
134134
Slice variable is initialized.
135135
"""
136-
SliceTS(rng::AbstractRNG, z0::PhasePoint) = SliceTS(z0, log(rand(rng)) - energy(z0), 1)
136+
SliceTS(rng::AbstractRNG, z0::PhasePoint) =
137+
SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1)
137138

138139
"""
139140
$(TYPEDEF)
@@ -143,7 +144,7 @@ Multinomial sampler for the starting single leaf tree.
143144
144145
Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L226
145146
"""
146-
MultinomialTS(rng::AbstractRNG, z0::PhasePoint) = MultinomialTS(z0, zero(energy(z0)))
147+
MultinomialTS(rng::AbstractRNG, z0::PhasePoint) = MultinomialTS(z0, zero(neg_energy(z0)))
147148

148149
"""
149150
$(TYPEDEF)
@@ -153,7 +154,7 @@ Create a slice sampler for a single leaf tree:
153154
- the number of acceptable candicates is computed by comparing the slice variable against the current energy.
154155
"""
155156
function SliceTS(s::SliceTS, H0::AbstractFloat, zcand::PhasePoint)
156-
return SliceTS(zcand, s.ℓu, (s.ℓu <= -energy(zcand)) ? 1 : 0)
157+
return SliceTS(zcand, s.ℓu, Int(s.ℓu <= neg_energy(zcand)))
157158
end
158159

159160
"""
@@ -163,13 +164,13 @@ Multinomial sampler for a trajectory consisting only a leaf node.
163164
- tree weight is the (unnormalised) energy of the leaf.
164165
"""
165166
function MultinomialTS(s::MultinomialTS, H0::AbstractFloat, zcand::PhasePoint)
166-
return MultinomialTS(zcand, H0 - energy(zcand))
167+
return MultinomialTS(zcand, H0 + neg_energy(zcand))
167168
end
168169

169170
function combine(rng::AbstractRNG, s1::SliceTS, s2::SliceTS)
170171
@assert s1.ℓu == s2.ℓu "Cannot combine two slice sampler with different slice variable"
171172
n = s1.n + s2.n
172-
zcand = rand(rng) < s1.n / n ? s1.zcand : s2.zcand
173+
zcand = n * rand(rng) < s1.n ? s1.zcand : s2.zcand
173174
return SliceTS(zcand, s1.ℓu, n)
174175
end
175176

@@ -181,7 +182,7 @@ end
181182

182183
function combine(rng::AbstractRNG, s1::MultinomialTS, s2::MultinomialTS)
183184
ℓw = logaddexp(s1.ℓw, s2.ℓw)
184-
zcand = rand(rng) < exp(s1.ℓw - ℓw) ? s1.zcand : s2.zcand
185+
zcand = ℓw < s1.ℓw + Random.randexp(rng) ? s1.zcand : s2.zcand
185186
return MultinomialTS(zcand, ℓw)
186187
end
187188

@@ -190,10 +191,10 @@ function combine(zcand::PhasePoint, s1::MultinomialTS, s2::MultinomialTS)
190191
return MultinomialTS(zcand, ℓw)
191192
end
192193

193-
mh_accept(rng::AbstractRNG, s::SliceTS, s′::SliceTS) = rand(rng) < min(1, s′.n / s.n)
194+
mh_accept(rng::AbstractRNG, s::SliceTS, s′::SliceTS) = s.n * rand(rng) < s′.n
194195

195196
function mh_accept(rng::AbstractRNG, s::MultinomialTS, s′::MultinomialTS)
196-
return rand(rng) < min(1, exp(s′.ℓw - s.ℓw))
197+
return s.ℓw < s′.ℓw + Random.randexp(rng)
197198
end
198199

199200
"""
@@ -696,16 +697,16 @@ function transition(
696697
j = 0
697698
while !isterminated(termination) && j < τ.termination_criterion.max_depth
698699
# Sample a direction; `-1` means left and `1` means right
699-
v = rand(rng, [-1, 1])
700-
if v == -1
700+
vleft = rand(rng, Bool)
701+
if vleft
701702
# Create a tree with depth `j` on the left
702703
tree′, sampler′, termination′ =
703-
build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0)
704+
build_tree(rng, τ, h, tree.zleft, sampler, -1, j, H0)
704705
treeleft, treeright = tree′, tree
705706
else
706707
# Create a tree with depth `j` on the right
707708
tree′, sampler′, termination′ =
708-
build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
709+
build_tree(rng, τ, h, tree.zright, sampler, 1, j, H0)
709710
treeleft, treeright = tree, tree′
710711
end
711712
# Perform a MH step and increse depth if not terminated
@@ -842,8 +843,8 @@ function mh_accept_ratio(
842843
Horiginal::T,
843844
Hproposal::T,
844845
) where {T<:AbstractFloat}
846+
accept = Hproposal < Horiginal + Random.randexp(rng, T)
845847
α = min(one(T), exp(Horiginal - Hproposal))
846-
accept = rand(rng, T) < α
847848
return accept, α
848849
end
849850

@@ -852,12 +853,13 @@ function mh_accept_ratio(
852853
Horiginal::AbstractVector{<:T},
853854
Hproposal::AbstractVector{<:T},
854855
) where {T<:AbstractFloat}
855-
α = min.(one(T), exp.(Horiginal .- Hproposal))
856856
# NOTE: There is a chance that sharing the RNG over multiple
857857
# chains for accepting / rejecting might couple
858858
# the chains. We need to revisit this more rigirously
859859
# in the future. See discussions at
860860
# https://github.com/TuringLang/AdvancedHMC.jl/pull/166#pullrequestreview-367216534
861-
accept = rand(rng, T, length(Horiginal)) .< α
861+
_rng = rng isa AbstractRNG ? (rng,) : rng
862+
accept = Hproposal .< Horiginal .+ Random.randexp.(_rng, (T,))
863+
α = min.(one(T), exp.(Horiginal .- Hproposal))
862864
return accept, α
863865
end

test/abstractmcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Statistics: mean
33

44
@testset "AbstractMCMC w/ gdemo" begin
55
rng = MersenneTwister(0)
6-
n_samples = 5_000
6+
n_samples = 10_000
77
n_adapts = 5_000
88
θ_init = randn(rng, 2)
99

0 commit comments

Comments
 (0)