Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ $(TYPEDEF)
Slice sampler for the starting single leaf tree.
Slice variable is initialized.
"""
SliceTS(rng::AbstractRNG, z0::PhasePoint) = SliceTS(z0, log(rand(rng)) - energy(z0), 1)
SliceTS(rng::AbstractRNG, z0::PhasePoint) =
SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1)

"""
$(TYPEDEF)
Expand All @@ -143,7 +144,7 @@ Multinomial sampler for the starting single leaf tree.

Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L226
"""
MultinomialTS(rng::AbstractRNG, z0::PhasePoint) = MultinomialTS(z0, zero(energy(z0)))
MultinomialTS(rng::AbstractRNG, z0::PhasePoint) = MultinomialTS(z0, zero(neg_energy(z0)))

"""
$(TYPEDEF)
Expand All @@ -153,7 +154,7 @@ Create a slice sampler for a single leaf tree:
- the number of acceptable candicates is computed by comparing the slice variable against the current energy.
"""
function SliceTS(s::SliceTS, H0::AbstractFloat, zcand::PhasePoint)
return SliceTS(zcand, s.ℓu, (s.ℓu <= -energy(zcand)) ? 1 : 0)
return SliceTS(zcand, s.ℓu, Int(s.ℓu <= neg_energy(zcand)))
end

"""
Expand All @@ -163,13 +164,13 @@ Multinomial sampler for a trajectory consisting only a leaf node.
- tree weight is the (unnormalised) energy of the leaf.
"""
function MultinomialTS(s::MultinomialTS, H0::AbstractFloat, zcand::PhasePoint)
return MultinomialTS(zcand, H0 - energy(zcand))
return MultinomialTS(zcand, H0 + neg_energy(zcand))
end

function combine(rng::AbstractRNG, s1::SliceTS, s2::SliceTS)
@assert s1.ℓu == s2.ℓu "Cannot combine two slice sampler with different slice variable"
n = s1.n + s2.n
zcand = rand(rng) < s1.n / n ? s1.zcand : s2.zcand
zcand = n * rand(rng) < s1.n ? s1.zcand : s2.zcand
return SliceTS(zcand, s1.ℓu, n)
end

Expand All @@ -181,7 +182,7 @@ end

function combine(rng::AbstractRNG, s1::MultinomialTS, s2::MultinomialTS)
ℓw = logaddexp(s1.ℓw, s2.ℓw)
zcand = rand(rng) < exp(s1.ℓw - ℓw) ? s1.zcand : s2.zcand
zcand = ℓw < s1.ℓw + Random.randexp(rng) ? s1.zcand : s2.zcand
return MultinomialTS(zcand, ℓw)
end

Expand All @@ -190,10 +191,10 @@ function combine(zcand::PhasePoint, s1::MultinomialTS, s2::MultinomialTS)
return MultinomialTS(zcand, ℓw)
end

mh_accept(rng::AbstractRNG, s::SliceTS, s′::SliceTS) = rand(rng) < min(1, s′.n / s.n)
mh_accept(rng::AbstractRNG, s::SliceTS, s′::SliceTS) = s.n * rand(rng) < s′.n

function mh_accept(rng::AbstractRNG, s::MultinomialTS, s′::MultinomialTS)
return rand(rng) < min(1, exp(s′.ℓw - s.ℓw))
return s.ℓw < s′.ℓw + Random.randexp(rng)
end

"""
Expand Down Expand Up @@ -696,16 +697,16 @@ function transition(
j = 0
while !isterminated(termination) && j < τ.termination_criterion.max_depth
# Sample a direction; `-1` means left and `1` means right
v = rand(rng, [-1, 1])
if v == -1
vleft = rand(rng, Bool)
if vleft
# Create a tree with depth `j` on the left
tree′, sampler′, termination′ =
build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0)
build_tree(rng, τ, h, tree.zleft, sampler, -1, j, H0)
treeleft, treeright = tree′, tree
else
# Create a tree with depth `j` on the right
tree′, sampler′, termination′ =
build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
build_tree(rng, τ, h, tree.zright, sampler, 1, j, H0)
treeleft, treeright = tree, tree′
end
# Perform a MH step and increse depth if not terminated
Expand Down Expand Up @@ -842,8 +843,8 @@ function mh_accept_ratio(
Horiginal::T,
Hproposal::T,
) where {T<:AbstractFloat}
accept = Hproposal < Horiginal + Random.randexp(rng, T)
α = min(one(T), exp(Horiginal - Hproposal))
accept = rand(rng, T) < α
return accept, α
end

Expand All @@ -852,12 +853,13 @@ function mh_accept_ratio(
Horiginal::AbstractVector{<:T},
Hproposal::AbstractVector{<:T},
) where {T<:AbstractFloat}
α = min.(one(T), exp.(Horiginal .- Hproposal))
# NOTE: There is a chance that sharing the RNG over multiple
# chains for accepting / rejecting might couple
# the chains. We need to revisit this more rigirously
# in the future. See discussions at
# https://github.com/TuringLang/AdvancedHMC.jl/pull/166#pullrequestreview-367216534
accept = rand(rng, T, length(Horiginal)) .< α
_rng = rng isa AbstractRNG ? (rng,) : rng
accept = Hproposal .< Horiginal .+ Random.randexp.(_rng, (T,))
α = min.(one(T), exp.(Horiginal .- Hproposal))
return accept, α
end
2 changes: 1 addition & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Statistics: mean

@testset "AbstractMCMC w/ gdemo" begin
rng = MersenneTwister(0)
n_samples = 5_000
n_samples = 10_000
n_adapts = 5_000
θ_init = randn(rng, 2)

Expand Down
Loading