diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 95a6fdeeb..5951aeea1 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -1,7 +1,28 @@ +# Momenta sampling with different metrics + function sample_momenta(n::Int) Float64[random(normal, 0, 1) for _=1:n] end +function sample_momenta(n::Int, metric::AbstractVector) + @assert all(>(0), metric) "All diagonal metric values must be positive" + return sqrt.(metric) .* sample_momenta(n) +end + +function sample_momenta(n::Int, metric::LinearAlgebra.Diagonal) + sample_momenta(n::Int, LinearAlgebra.diag(metric)) +end + +function sample_momenta(n::Int, metric::AbstractMatrix) + mvnormal(zeros(n), metric) +end + +function sample_momenta(n::Int, metric::Nothing) + sample_momenta(n) +end + +# Assessing momenta log probabilities with different metrics + function assess_momenta(momenta) logprob = 0. for val in momenta @@ -10,21 +31,56 @@ function assess_momenta(momenta) logprob end +function assess_momenta(momenta, metric::AbstractVector) + logprob = 0. + for (val, m) in zip(momenta, metric) + logprob += logpdf(normal, val, 0, sqrt(m)) + end + logprob +end + +function assess_momenta(momenta, metric::LinearAlgebra.Diagonal) + assess_momenta(momenta, LinearAlgebra.diag(metric)) +end + +function assess_momenta(momenta, metric::AbstractMatrix) + logpdf(mvnormal, momenta, zeros(length(momenta)), metric) +end + +function assess_momenta(momenta, metric::Nothing) + assess_momenta(momenta) +end + """ (new_trace, accepted) = hmc( trace, selection::Selection; L=10, eps=0.1, - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), metric = nothing) + +Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the +selected addresses, returning the new trace (which is equal to the previous trace +if the move was not accepted) and a `Bool` indicating whether the move was accepted or not. + +Hamilton's equations are numerically integrated using leapfrog integration with +step size `eps` for `L` steps and initial momenta sampled from a Gaussian distribution with +covariance given by `metric` (mass matrix). + +Sampling with HMC is improved by using a metric/mass matrix that approximates the +**inverse** covariance of the target distribution, and is equivalent to a linear transformation +of the parameter space (see Neal, 2011). The following options are supported for `metric`: -Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not. +- `nothing` (default): identity matrix +- `Vector`: diagonal matrix with the given vector as the diagonal +- `Diagonal`: diagonal matrix lowers to the vector of the diagonal entries +- `Matrix`: dense matrix -Hamilton's equations are numerically integrated using leapfrog integration with step size `eps` for `L` steps. See equations (5.18)-(5.20) of Neal (2011). +See equations (5.18)-(5.20) of Neal (2011). # References Neal, Radford M. (2011), "MCMC Using Hamiltonian Dynamics", Handbook of Markov Chain Monte Carlo, pp. 113-162. URL: http://www.mcmchandbook.net/HandbookChapter5.pdf """ function hmc( trace::Trace, selection::Selection; L=10, eps=0.1, - check=false, observations=EmptyChoiceMap()) + check=false, observations=EmptyChoiceMap(), metric = nothing) prev_model_score = get_score(trace) args = get_args(trace) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing @@ -35,8 +91,8 @@ function hmc( (_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) values = to_array(values_trie, Float64) gradient = to_array(gradient_trie, Float64) - momenta = sample_momenta(length(values)) - prev_momenta_score = assess_momenta(momenta) + momenta = sample_momenta(length(values), metric) + prev_momenta_score = assess_momenta(momenta, metric) for step=1:L # half step on momenta @@ -60,7 +116,7 @@ function hmc( new_model_score = get_score(new_trace) # assess new momenta score (negative kinetic energy) - new_momenta_score = assess_momenta(-momenta) + new_momenta_score = assess_momenta(-momenta, metric) # accept or reject alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score diff --git a/test/inference/hmc.jl b/test/inference/hmc.jl index c5887751c..e465351b6 100644 --- a/test/inference/hmc.jl +++ b/test/inference/hmc.jl @@ -1,5 +1,6 @@ -@testset "hmc" begin - +@testset "hmc tests" begin + import LinearAlgebra, Random + # smoke test a function without retval gradient @gen function foo() x = @trace(normal(0, 1), :x) @@ -17,4 +18,135 @@ (trace, _) = generate(foo, ()) (new_trace, accepted) = hmc(trace, select(:x)) + + # For Normal(0,1), grad should be -x + (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0) + values = to_array(values_trie, Float64) + grad = to_array(gradient_trie, Float64) + @test values ≈ -grad + + # smoke test with vector metric + @gen function bar() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + (trace, _) = generate(bar, ()) + metric_vec = [1.0, 2.0] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) + + # smoke test with Diagonal metric + (trace, _) = generate(bar, ()) + metric_diag = LinearAlgebra.Diagonal([1.0, 2.0]) + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag) + + # smoke test with Dense matrix metric + (trace, _) = generate(bar, ()) + metric_dense = [1.0 0.1; 0.1 2.0] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense) + + # smoke test with vector metric and retval gradient + @gen (grad) function bar_grad() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + (trace, _) = generate(bar_grad, ()) + metric_vec = [0.5, 1.5] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec) + + # For each Normal(0,1), grad should be -x + (_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0) + values = to_array(values_trie, Float64) + grad = to_array(gradient_trie, Float64) + @test values ≈ -grad + + # smoke test with Diagonal metric and retval gradient + (trace, _) = generate(bar_grad, ()) + metric_diag = LinearAlgebra.Diagonal([0.5, 1.5]) + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag) + + # smoke test with Dense matrix metric and retval gradient + (trace, _) = generate(bar_grad, ()) + metric_dense = [0.5 0.2; 0.2 1.5] + (new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense) end + +@testset "hmc metric behavior" begin + import LinearAlgebra, Random + + # test that different metrics produce different behavior + @gen function test_metric_effect() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + (trace1, _) = generate(test_metric_effect, ()) + + + # Set RNG to a known state for comparison + Random.seed!(1) + + # Run HMC with identity metric (default) + (trace_identity, _) = hmc(trace1, select(:x, :y); L=5) + + # Reset RNG to same state for comparison + Random.seed!(1) + + # Run HMC with scaled metric (should behave differently) + metric_scaled = [10.0, 0.1] # Very different scales + (trace_scaled, _) = hmc(trace1, select(:x, :y); L=5, metric=metric_scaled) + + # With same RNG sequence but different metrics, should get different results + @test get_choices(trace_identity) != get_choices(trace_scaled) + + # With same metric but different representations, should get similar results + # Test many times to check statistical similarity + acceptances_diag = Float64[] + acceptances_dense = Float64[] + + for i in 1:50 + # Reset to predictable state for each iteration + Random.seed!(i) + (_, accepted_diag) = hmc(trace1, select(:x, :y); + metric=LinearAlgebra.Diagonal([2.0, 3.0])) + + # Reset to same state for comparison + Random.seed!(i) + (_, accepted_dense) = hmc(trace1, select(:x, :y); + metric=[2.0 0.0; 0.0 3.0]) + + # Collect acceptance results + push!(acceptances_diag, float(accepted_diag)) + push!(acceptances_dense, float(accepted_dense)) + end + + # # Should have similar acceptance rates (within 20%) + rate_diag = Distributions.mean(acceptances_diag) + rate_dense = Distributions.mean(acceptances_dense) + @test abs(rate_diag - rate_dense) < 0.2 + + +end + +@testset "Bad metric catches" begin + @gen function bar() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + return x + y + end + + bad_metrics =([-1.0 -20.0; 0.0 1.0], # Bad dense, + LinearAlgebra.Diagonal([-1.0, -20.0]), # Bad diag + [-5.0, 20.0], # Bad vector diag + ) + + for bad_metric in bad_metrics + (trace, _) = generate(bar, ()) + @test_throws Exception hmc(trace, select(:x, :y); metric=bad_metric) + end + +end \ No newline at end of file