@@ -133,7 +133,8 @@ $(TYPEDEF)
133133Slice sampler for the starting single leaf tree.
134134Slice variable is initialized.
135135"""
136- SliceTS (rng:: AbstractRNG , z0:: PhasePoint ) = SliceTS (z0, log (rand (rng)) - energy (z0), 1 )
136+ SliceTS (rng:: AbstractRNG , z0:: PhasePoint ) =
137+ SliceTS (z0, neg_energy (z0) - Random. randexp (rng), 1 )
137138
138139"""
139140$(TYPEDEF)
@@ -143,7 +144,7 @@ Multinomial sampler for the starting single leaf tree.
143144
144145Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L226
145146"""
146- MultinomialTS (rng:: AbstractRNG , z0:: PhasePoint ) = MultinomialTS (z0, zero (energy (z0)))
147+ MultinomialTS (rng:: AbstractRNG , z0:: PhasePoint ) = MultinomialTS (z0, zero (neg_energy (z0)))
147148
148149"""
149150$(TYPEDEF)
@@ -153,7 +154,7 @@ Create a slice sampler for a single leaf tree:
153154- the number of acceptable candicates is computed by comparing the slice variable against the current energy.
154155"""
155156function SliceTS (s:: SliceTS , H0:: AbstractFloat , zcand:: PhasePoint )
156- return SliceTS (zcand, s. ℓu, (s. ℓu <= - energy (zcand)) ? 1 : 0 )
157+ return SliceTS (zcand, s. ℓu, Int (s. ℓu <= neg_energy (zcand)))
157158end
158159
159160"""
@@ -163,13 +164,13 @@ Multinomial sampler for a trajectory consisting only a leaf node.
163164- tree weight is the (unnormalised) energy of the leaf.
164165"""
165166function MultinomialTS (s:: MultinomialTS , H0:: AbstractFloat , zcand:: PhasePoint )
166- return MultinomialTS (zcand, H0 - energy (zcand))
167+ return MultinomialTS (zcand, H0 + neg_energy (zcand))
167168end
168169
169170function combine (rng:: AbstractRNG , s1:: SliceTS , s2:: SliceTS )
170171 @assert s1. ℓu == s2. ℓu " Cannot combine two slice sampler with different slice variable"
171172 n = s1. n + s2. n
172- zcand = rand (rng) < s1. n / n ? s1. zcand : s2. zcand
173+ zcand = n * rand (rng) < s1. n ? s1. zcand : s2. zcand
173174 return SliceTS (zcand, s1. ℓu, n)
174175end
175176
181182
182183function combine (rng:: AbstractRNG , s1:: MultinomialTS , s2:: MultinomialTS )
183184 ℓw = logaddexp (s1. ℓw, s2. ℓw)
184- zcand = rand (rng) < exp ( s1. ℓw - ℓw ) ? s1. zcand : s2. zcand
185+ zcand = ℓw < s1. ℓw + Random . randexp (rng ) ? s1. zcand : s2. zcand
185186 return MultinomialTS (zcand, ℓw)
186187end
187188
@@ -190,10 +191,10 @@ function combine(zcand::PhasePoint, s1::MultinomialTS, s2::MultinomialTS)
190191 return MultinomialTS (zcand, ℓw)
191192end
192193
193- mh_accept (rng:: AbstractRNG , s:: SliceTS , s′:: SliceTS ) = rand (rng) < min ( 1 , s′. n / s . n)
194+ mh_accept (rng:: AbstractRNG , s:: SliceTS , s′:: SliceTS ) = s . n * rand (rng) < s′. n
194195
195196function mh_accept (rng:: AbstractRNG , s:: MultinomialTS , s′:: MultinomialTS )
196- return rand (rng) < min ( 1 , exp ( s′. ℓw - s . ℓw) )
197+ return s . ℓw < s′. ℓw + Random . randexp (rng )
197198end
198199
199200"""
@@ -696,16 +697,16 @@ function transition(
696697 j = 0
697698 while ! isterminated (termination) && j < τ. termination_criterion. max_depth
698699 # Sample a direction; `-1` means left and `1` means right
699- v = rand (rng, [ - 1 , 1 ] )
700- if v == - 1
700+ vleft = rand (rng, Bool )
701+ if vleft
701702 # Create a tree with depth `j` on the left
702703 tree′, sampler′, termination′ =
703- build_tree (rng, τ, h, tree. zleft, sampler, v , j, H0)
704+ build_tree (rng, τ, h, tree. zleft, sampler, - 1 , j, H0)
704705 treeleft, treeright = tree′, tree
705706 else
706707 # Create a tree with depth `j` on the right
707708 tree′, sampler′, termination′ =
708- build_tree (rng, τ, h, tree. zright, sampler, v , j, H0)
709+ build_tree (rng, τ, h, tree. zright, sampler, 1 , j, H0)
709710 treeleft, treeright = tree, tree′
710711 end
711712 # Perform a MH step and increse depth if not terminated
@@ -842,8 +843,8 @@ function mh_accept_ratio(
842843 Horiginal:: T ,
843844 Hproposal:: T ,
844845) where {T<: AbstractFloat }
846+ accept = Hproposal < Horiginal + Random. randexp (rng, T)
845847 α = min (one (T), exp (Horiginal - Hproposal))
846- accept = rand (rng, T) < α
847848 return accept, α
848849end
849850
@@ -852,12 +853,13 @@ function mh_accept_ratio(
852853 Horiginal:: AbstractVector{<:T} ,
853854 Hproposal:: AbstractVector{<:T} ,
854855) where {T<: AbstractFloat }
855- α = min .(one (T), exp .(Horiginal .- Hproposal))
856856 # NOTE: There is a chance that sharing the RNG over multiple
857857 # chains for accepting / rejecting might couple
858858 # the chains. We need to revisit this more rigirously
859859 # in the future. See discussions at
860860 # https://github.com/TuringLang/AdvancedHMC.jl/pull/166#pullrequestreview-367216534
861- accept = rand (rng, T, length (Horiginal)) .< α
861+ _rng = rng isa AbstractRNG ? (rng,) : rng
862+ accept = Hproposal .< Horiginal .+ Random. randexp .(_rng, (T,))
863+ α = min .(one (T), exp .(Horiginal .- Hproposal))
862864 return accept, α
863865end
0 commit comments