70 changes: 63 additions & 7 deletions src/inference/hmc.jl
@@ -1,7 +1,28 @@
# Momenta sampling with different metrics
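# For a mass matrix M given by `metric`, momenta are drawn from N(0, M):
# `nothing` corresponds to the identity, a `Vector` or `Diagonal` to a diagonal M,
# and a dense `Matrix` to a full M.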

function sample_momenta(n::Int)
Float64[random(normal, 0, 1) for _=1:n]
end

function sample_momenta(n::Int, metric::AbstractVector)
@assert all(>(0), metric) "All diagonal metric values must be positive"
return sqrt.(metric) .* sample_momenta(n)
end

function sample_momenta(n::Int, metric::LinearAlgebra.Diagonal)
sample_momenta(n, LinearAlgebra.diag(metric))
end

function sample_momenta(n::Int, metric::AbstractMatrix)
mvnormal(zeros(n), metric)
end

function sample_momenta(n::Int, metric::Nothing)
sample_momenta(n)
end

# Assessing momenta log probabilities with different metrics
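# Each method evaluates the log density of the same N(0, M) distribution used in
# `sample_momenta`, i.e. the negative kinetic energy up to an additive constant.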

function assess_momenta(momenta)
logprob = 0.
for val in momenta
@@ -10,21 +31,56 @@ function assess_momenta(momenta)
logprob
end

function assess_momenta(momenta, metric::AbstractVector)
logprob = 0.
for (val, m) in zip(momenta, metric)
logprob += logpdf(normal, val, 0, sqrt(m))
end
logprob
end

function assess_momenta(momenta, metric::LinearAlgebra.Diagonal)
assess_momenta(momenta, LinearAlgebra.diag(metric))
end

function assess_momenta(momenta, metric::AbstractMatrix)
logpdf(mvnormal, momenta, zeros(length(momenta)), metric)
end

function assess_momenta(momenta, metric::Nothing)
assess_momenta(momenta)
end

"""
(new_trace, accepted) = hmc(
trace, selection::Selection; L=10, eps=0.1,
check=false, observations=EmptyChoiceMap(), metric = nothing)

Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the
selected addresses, returning the new trace (which is equal to the previous trace
if the move was not accepted) and a `Bool` indicating whether the move was accepted or not.

Hamilton's equations are numerically integrated using leapfrog integration with
step size `eps` for `L` steps and initial momenta sampled from a Gaussian distribution with
covariance given by `metric` (mass matrix).

Sampling with HMC is improved by using a metric (mass matrix) that approximates the
**inverse** covariance of the target distribution; this is equivalent to running HMC under a
linear transformation of the parameter space (see Neal, 2011). The following options are
supported for `metric` (see the example below):

- `nothing` (default): identity matrix
- `Vector`: diagonal matrix with the given vector as the diagonal
- `Diagonal`: diagonal matrix; lowered to the vector of its diagonal entries
- `Matrix`: dense matrix

For the leapfrog integration, see equations (5.18)-(5.20) of Neal (2011).
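
# Examples
A minimal usage sketch; the generative function and addresses below are illustrative only:
```julia
using Gen, LinearAlgebra

@gen function model()
    @trace(normal(0, 1), :x)
    @trace(normal(0, 1), :y)
end

(trace, _) = generate(model, ())
(trace, _) = hmc(trace, select(:x, :y))                               # identity metric
(trace, _) = hmc(trace, select(:x, :y); metric=[2.0, 0.5])            # Vector: diagonal metric
(trace, _) = hmc(trace, select(:x, :y); metric=Diagonal([2.0, 0.5]))  # Diagonal metric
(trace, _) = hmc(trace, select(:x, :y); metric=[2.0 0.3; 0.3 0.5])    # dense mass matrix
```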

# References
Neal, Radford M. (2011), "MCMC Using Hamiltonian Dynamics", Handbook of Markov Chain Monte Carlo, pp. 113-162. URL: http://www.mcmchandbook.net/HandbookChapter5.pdf
"""
function hmc(
trace::Trace, selection::Selection; L=10, eps=0.1,
check=false, observations=EmptyChoiceMap(), metric = nothing)
prev_model_score = get_score(trace)
args = get_args(trace)
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
@@ -35,8 +91,8 @@ function hmc(
(_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
values = to_array(values_trie, Float64)
gradient = to_array(gradient_trie, Float64)
momenta = sample_momenta(length(values), metric)
prev_momenta_score = assess_momenta(momenta, metric)
for step=1:L

# half step on momenta
@@ -60,7 +116,7 @@
new_model_score = get_score(new_trace)

# assess new momenta score (negative kinetic energy)
new_momenta_score = assess_momenta(-momenta, metric)

# accept or reject
alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
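
The remaining leapfrog body of `hmc` is collapsed in the diff above. For orientation only, here is a generic mass-matrix leapfrog step in the form of Neal (2011), equations (5.18)-(5.20); the names `leapfrog_step`, `q`, `p`, `grad_U`, and `M` are illustrative, and this sketch is not the code elided from this file.

```julia
# One leapfrog step with step size eps and mass (metric) matrix M, an AbstractMatrix.
# grad_U(q) returns the gradient of the potential U(q) = -log p(q).
function leapfrog_step(q::Vector{Float64}, p::Vector{Float64}, eps::Real, grad_U, M)
    p = p .- (eps / 2) .* grad_U(q)   # half step on momenta
    q = q .+ eps .* (M \ p)           # full step on positions, scaled by M^{-1}
    p = p .- (eps / 2) .* grad_U(q)   # half step on momenta
    return q, p
end
```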
136 changes: 134 additions & 2 deletions test/inference/hmc.jl
@@ -1,5 +1,6 @@

@testset "hmc tests" begin
import LinearAlgebra, Random

# smoke test a function without retval gradient
@gen function foo()
x = @trace(normal(0, 1), :x)
@@ -17,4 +18,135 @@

(trace, _) = generate(foo, ())
(new_trace, accepted) = hmc(trace, select(:x))

# For Normal(0,1), grad should be -x
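# (d/dx) log N(x; 0, 1) = (d/dx) (-x^2/2 - log(sqrt(2pi))) = -x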
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0)
values = to_array(values_trie, Float64)
grad = to_array(gradient_trie, Float64)
@test values ≈ -grad

# smoke test with vector metric
@gen function bar()
x = @trace(normal(0, 1), :x)
y = @trace(normal(0, 1), :y)
return x + y
end

(trace, _) = generate(bar, ())
metric_vec = [1.0, 2.0]
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)

# smoke test with Diagonal metric
(trace, _) = generate(bar, ())
metric_diag = LinearAlgebra.Diagonal([1.0, 2.0])
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag)

# smoke test with Dense matrix metric
(trace, _) = generate(bar, ())
metric_dense = [1.0 0.1; 0.1 2.0]
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense)

# smoke test with vector metric and retval gradient
@gen (grad) function bar_grad()
x = @trace(normal(0, 1), :x)
y = @trace(normal(0, 1), :y)
return x + y
end

(trace, _) = generate(bar_grad, ())
metric_vec = [0.5, 1.5]
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)

# For each Normal(0,1), grad should be -x
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0)
values = to_array(values_trie, Float64)
grad = to_array(gradient_trie, Float64)
@test values ≈ -grad

# smoke test with Diagonal metric and retval gradient
(trace, _) = generate(bar_grad, ())
metric_diag = LinearAlgebra.Diagonal([0.5, 1.5])
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag)

# smoke test with Dense matrix metric and retval gradient
(trace, _) = generate(bar_grad, ())
metric_dense = [0.5 0.2; 0.2 1.5]
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense)
end

@testset "hmc metric behavior" begin
import LinearAlgebra, Random

# test that different metrics produce different behavior
@gen function test_metric_effect()
x = @trace(normal(0, 1), :x)
y = @trace(normal(0, 1), :y)
return x + y
end

(trace1, _) = generate(test_metric_effect, ())


# Set RNG to a known state for comparison
Random.seed!(1)

# Run HMC with identity metric (default)
(trace_identity, _) = hmc(trace1, select(:x, :y); L=5)

# Reset RNG to same state for comparison
Random.seed!(1)

# Run HMC with scaled metric (should behave differently)
metric_scaled = [10.0, 0.1] # Very different scales
(trace_scaled, _) = hmc(trace1, select(:x, :y); L=5, metric=metric_scaled)

# With same RNG sequence but different metrics, should get different results
@test get_choices(trace_identity) != get_choices(trace_scaled)

# With same metric but different representations, should get similar results
# Test many times to check statistical similarity
acceptances_diag = Float64[]
acceptances_dense = Float64[]

for i in 1:50
# Reset to predictable state for each iteration
Random.seed!(i)
(_, accepted_diag) = hmc(trace1, select(:x, :y);
metric=LinearAlgebra.Diagonal([2.0, 3.0]))

# Reset to same state for comparison
Random.seed!(i)
(_, accepted_dense) = hmc(trace1, select(:x, :y);
metric=[2.0 0.0; 0.0 3.0])

# Collect acceptance results
push!(acceptances_diag, float(accepted_diag))
push!(acceptances_dense, float(accepted_dense))
end

# Acceptance rates should be similar (absolute difference below 0.2)
rate_diag = Distributions.mean(acceptances_diag)
rate_dense = Distributions.mean(acceptances_dense)
@test abs(rate_diag - rate_dense) < 0.2


end

@testset "Bad metric catches" begin
@gen function bar()
x = @trace(normal(0, 1), :x)
y = @trace(normal(0, 1), :y)
return x + y
end

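# Each of these should raise: the Vector and Diagonal cases trip the positivity
# assert in `sample_momenta`, and the dense case should fail when the multivariate
# normal rejects a non-symmetric / non-positive-definite covariance.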
bad_metrics = (
[-1.0 -20.0; 0.0 1.0], # bad dense matrix
LinearAlgebra.Diagonal([-1.0, -20.0]), # bad Diagonal
[-5.0, 20.0], # bad vector diagonal
)

for bad_metric in bad_metrics
(trace, _) = generate(bar, ())
@test_throws Exception hmc(trace, select(:x, :y); metric=bad_metric)
end

end