diff --git a/examples/c-comparisons/script.jl b/examples/c-comparisons/script.jl
index 011c17fd..05cdbfd5 100644
--- a/examples/c-comparisons/script.jl
+++ b/examples/c-comparisons/script.jl
@@ -1,7 +1,7 @@
-# # Binary Classification with Laplace approximation
+# # Binary Classification with different approximations
 #
 # This example demonstrates how to carry out non-conjugate Gaussian process
-# inference using the Laplace approximation.
+# inference using the Laplace approximation and Expectation Propagation.
 #
 # For a basic introduction to the functionality of this library, please refer
 # to the [User Guide](@ref).
@@ -78,6 +78,8 @@ lf = build_latent_gp(theta0)
 
 lf.f.kernel
 
+# ## Approximate inference with the Laplace approximation
+#
 # We can now compute the Laplace approximation ``q(f)`` to the true posterior
 # ``p(f | y)``:
 
@@ -104,7 +106,7 @@ plot_samples!(Xgrid, f_post)
 # process regression, the maximization objective is the marginal likelihood.
 # Here, we can only optimise an _approximation_ to the marginal likelihood.
 
-# ## Optimise the hyperparameters
+# ### Optimise the hyperparameters
 #
 # ApproximateGPs provides a convenience function `build_laplace_objective` that
 # constructs an objective function for optimising the hyperparameters, based on
@@ -137,3 +139,28 @@ f_post2 = posterior(LaplaceApproximation(; f_init=objective.f), lf2(X), Y)
 
 p2 = plot_data()
 plot_samples!(Xgrid, f_post2)
+
+# ## Approximate inference with Expectation Propagation (EP)
+#
+# !!! warning
+#     The EP implementation is currently just an experimental prototype and may
+#     not work for your use case. Any help is welcome!

+# Using the initial hyperparameter values:
+
+f_post_ep = posterior(ApproximateGPs.ExpectationPropagation(), lf(X), Y)
+
+p3 = plot_data()
+plot_samples!(Xgrid, f_post_ep)
+
+# Using the optimized hyperparameter values:
+#
+# !!! warning
+#     The approximate (log) marginal likelihood for EP has not yet been
+#     implemented. Here we re-use the optimized hyperparameters from the
+#     Laplace approximation for illustration purposes.
+
+f_post_ep2 = posterior(ApproximateGPs.ExpectationPropagation(), lf2(X), Y)
+
+p4 = plot_data()
+plot_samples!(Xgrid, f_post_ep2)
diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl
index 6d5e4d95..66332a33 100644
--- a/src/ApproximateGPs.jl
+++ b/src/ApproximateGPs.jl
@@ -21,6 +21,10 @@ include("LaplaceApproximationModule.jl")
 @reexport using .LaplaceApproximationModule:
     build_laplace_objective, build_laplace_objective!
 
+include("ExpectationPropagationModule.jl") +using .ExpectationPropagationModule: ExpectationPropagation +#@reexport using .ExpectationPropagationModule: ExpectationPropagation # still too experimental + include("deprecations.jl") include("TestUtils.jl") diff --git a/src/ExpectationPropagationModule.jl b/src/ExpectationPropagationModule.jl new file mode 100644 index 00000000..cd05582d --- /dev/null +++ b/src/ExpectationPropagationModule.jl @@ -0,0 +1,212 @@ +module ExpectationPropagationModule + +using ..API + +export ExpectationPropagation + +using LinearAlgebra +using Random: randperm + +using Distributions +using FastGaussQuadrature: gausshermite +using IrrationalConstants: log2π, sqrt2π, sqrt2, invsqrtπ +using Statistics +using StatsBase + +using AbstractGPs +using AbstractGPs: LatentFiniteGP + +struct ExpectationPropagation + maxiter::Int + epsilon::Float64 + n_gh::Int +end + +function ExpectationPropagation(; maxiter=100, epsilon=1e-6, n_gh=150) + return ExpectationPropagation(maxiter, epsilon, n_gh) +end + +function AbstractGPs.posterior(ep::ExpectationPropagation, lfx::LatentFiniteGP, ys) + ep_state = ep_inference(ep, lfx, ys) + # NOTE: here we simply piggyback on the SparseVariationalApproximation. + # Should AbstractGPs provide its own "GP conditioned on f(x) ~ q(f)" rather + # than just "conditioned on observation under some noise" (which is *not* + # the same thing...)? + return posterior(SparseVariationalApproximation(Centered(), lfx.fx, ep_state.q)) +end + +function ep_inference(ep::ExpectationPropagation, lfx::LatentFiniteGP, ys) + fx = lfx.fx + mean(fx) == zero(mean(fx)) || + error("non-zero prior mean currently not supported: discuss on GitHub issue #89") + length(ys) == length(fx) || error( + "ExpectationPropagation currently does not support multi-latent likelihoods; please open an issue on GitHub", + ) + dist_y_given_f = lfx.lik + K = cov(fx) + + return ep_inference(dist_y_given_f, ys, K; ep) +end + +function ep_inference(dist_y_given_f, ys, K; ep=nothing) + ep_problem = EPProblem(dist_y_given_f, ys, K; ep) + ep_state = EPState(ep_problem) + return ep_outer_loop(ep_problem, ep_state) +end + +function EPProblem(ep::ExpectationPropagation, p::MvNormal, lik_evals::AbstractVector) + return (; p, lik_evals, ep) +end + +function EPProblem(dist_y_given_f, ys, K; ep=nothing) + f_prior = MvNormal(K) + lik_evals = [f -> pdf(dist_y_given_f(f), y) for y in ys] + return EPProblem(ep, f_prior, lik_evals) +end + +function EPState(ep_problem, q::MvNormal, sites::AbstractVector) + return (; ep_problem, q, sites) +end + +function EPState(ep_problem) + N = length(ep_problem.lik_evals) + # TODO- manually keep track of canonical parameters and initialize precision to 0 + sites = [ + (; Z=NaN, log_Ztilde=NaN, q=NormalCanon(0.0, 1e-10), cav=NormalCanon(0.0, 1.0)) for + _ in 1:N + ] + q = ep_problem.p + return EPState(ep_problem, q, sites) +end + +function ep_approx_posterior(prior, sites::AbstractVector) + canon_site_dists = [convert(NormalCanon, t.q) for t in sites] + potentials = [q.η for q in canon_site_dists] + precisions = [q.λ for q in canon_site_dists] + ts_dist = MvNormalCanon(potentials, Diagonal(precisions)) + return mul_dist(prior, ts_dist) +end + +function ep_outer_loop(ep_problem, ep_state; maxiter=ep_problem.ep.maxiter) + for i in 1:maxiter + @info "Outer loop iteration $i" + new_state = ep_loop_over_sites(ep_problem, ep_state) + if ep_converged(ep_state.sites, new_state.sites; epsilon=ep_problem.ep.epsilon) + @info "converged" + break + else + ep_state = new_state + end + end + 
+    return ep_state
+end
+
+function ep_converged(old_sites, new_sites; epsilon=1e-6)
+    # TODO: improve convergence check
+    diff1 = [(t_old.q.η - t_new.q.η)^2 for (t_old, t_new) in zip(old_sites, new_sites)]
+    diff2 = [(t_old.q.λ - t_new.q.λ)^2 for (t_old, t_new) in zip(old_sites, new_sites)]
+    return mean(diff1) < epsilon && mean(diff2) < epsilon
+end
+
+function ep_loop_over_sites(ep_problem, ep_state)
+    # sites are updated in random order; TODO: make the ordering configurable?
+    for i in randperm(length(ep_problem.lik_evals))
+        @info "  Inner loop iteration $i"
+        new_site = ep_single_site_update(ep_problem, ep_state, i)
+
+        # TODO: rank-1 update
+        new_sites = deepcopy(ep_state.sites)
+        new_sites[i] = new_site
+        new_q = meanform(ep_approx_posterior(ep_problem.p, new_sites))
+        ep_state = EPState(ep_problem, new_q, new_sites)
+    end
+    return ep_state
+end
+
+function ep_single_site_update(ep_problem, ep_state, i::Int)
+    # marginal of the current posterior approximation at site i
+    q_fi = ith_marginal(ep_state.q, i)
+    # current approximate likelihood term ("site") for observation i
+    alik_i = epsite_dist(ep_state.sites[i])
+    # cavity distribution: posterior marginal with site i divided out
+    cav_i = div_dist(q_fi, alik_i)
+    # Gaussian moment-matched to the tilted distribution cavity × exact likelihood
+    qhat_i = moment_match(cav_i, ep_problem.lik_evals[i]; n_points=ep_problem.ep.n_gh)
+    Zhat = qhat_i.Z
+    # updated site: moment-matched Gaussian divided by the cavity
+    new_t = div_dist(qhat_i.q, cav_i)
+    # normalization constant of the updated site (cf. Rasmussen & Williams,
+    # GPML, Section 3.6)
+    var_sum = var(cav_i) + var(new_t)
+    Ztilde = Zhat * sqrt2π * sqrt(var_sum) * exp((mean(cav_i) - mean(new_t))^2 / (2var_sum))
+    log_Ztilde =
+        log(Zhat) +
+        log2π / 2 +
+        log(var_sum) / 2 +
+        (mean(cav_i) - mean(new_t))^2 / (2var_sum)
+    return (; Z=Ztilde, log_Ztilde=log_Ztilde, q=new_t, cav=cav_i) # cav_i only required by approx_lml test
+end
+
+function ith_marginal(d::Union{MvNormal,MvNormalCanon}, i::Int)
+    m = mean(d)
+    v = var(d)
+    return Normal(m[i], sqrt(v[i]))
+end
+
+function mul_dist(a::NormalCanon, b::NormalCanon)
+    # NormalCanon fields:
+    #   η::T  # σ^(-2) * μ
+    #   λ::T  # σ^(-2)
+    etaAmulB = a.η + b.η
+    lambdaAmulB = a.λ + b.λ
+    return NormalCanon(etaAmulB, lambdaAmulB)
+end
+
+mul_dist(a, b) = mul_dist(convert(NormalCanon, a), convert(NormalCanon, b))
+
+function mul_dist(a::MvNormalCanon, b::MvNormalCanon)
+    # MvNormalCanon fields:
+    #   h::V  # potential vector, i.e. inv(Σ) * μ
+    #   J::P  # precision matrix, i.e. inv(Σ)
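+    # as in the univariate case, multiplying Gaussian densities amounts to
+    # adding their canonical parameters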
+    hAmulB = a.h + b.h
+    JAmulB = a.J + b.J
+    return MvNormalCanon(hAmulB, JAmulB)
+end
+
+mul_dist(a::MvNormal, b) = mul_dist(canonform(a), b)
+
+function div_dist(a::NormalCanon, b::NormalCanon)
+    # NormalCanon fields:
+    #   η::T  # σ^(-2) * μ
+    #   λ::T  # σ^(-2)
+    etaAdivB = a.η - b.η
+    lambdaAdivB = a.λ - b.λ
+    return NormalCanon(etaAdivB, lambdaAdivB)
+end
+
+div_dist(a::Normal, b) = div_dist(convert(NormalCanon, a), b)
+div_dist(a, b::Normal) = div_dist(a, convert(NormalCanon, b))
+
+# unused alternative site parameterization in terms of moments (Z, m, s2);
+# the current implementation stores the site distribution directly in `q`:
+#function EPSite(Z, m, s2)
+#    return (; Z, m, s2)
+#end
+#
+#function epsite_dist(site)
+#    return Normal(site.m, sqrt(site.s2))
+#end
+
+epsite_dist(site) = site.q
+
+function epsite_pdf(site, f)
+    return site.Z * pdf(epsite_dist(site), f)
+end
+
+function moment_match(cav_i::Union{Normal,NormalCanon}, lik_eval_i; n_points=150)
+    # TODO: combine with expected_loglik / move into GPLikelihoods
+    # Gauss-Hermite quadrature of the zeroth, first, and second moments of the
+    # tilted distribution cavity × likelihood, using the substitution
+    # f = √2 σ x + μ so that ∫ N(f; μ, σ²) g(f) df ≈ (1/√π) Σᵢ wᵢ g(fᵢ)
+    xs, ws = gausshermite(n_points)
+    fs = (sqrt2 * std(cav_i)) .* xs .+ mean(cav_i)
+    lik_ws = lik_eval_i.(fs) .* ws
+    fs_lik_ws = fs .* lik_ws
+    m0 = invsqrtπ * sum(lik_ws)
+    m1 = invsqrtπ * sum(fs_lik_ws)
+    m2 = invsqrtπ * dot(fs_lik_ws, fs)
+    matched_Z = m0
+    matched_mean = m1 / m0
+    matched_var = m2 / m0 - matched_mean^2
+    return (; Z=matched_Z, q=Normal(matched_mean, sqrt(matched_var)))
+end
+
+end
diff --git a/src/LaplaceApproximationModule.jl b/src/LaplaceApproximationModule.jl
index d43615f7..bb1aae6a 100644
--- a/src/LaplaceApproximationModule.jl
+++ b/src/LaplaceApproximationModule.jl
@@ -154,8 +154,11 @@ function _check_laplace_inputs(
     lfx::LatentFiniteGP, ys; f_init=nothing, maxiter=100, newton_kwargs...
 )
     fx = lfx.fx
-    @assert mean(fx) == zero(mean(fx)) # might work with non-zero prior mean but not checked
-    @assert length(ys) == length(fx) # LaplaceApproximation currently does not support multi-latent likelihoods
+    mean(fx) == zero(mean(fx)) ||
+        error("non-zero prior mean currently not supported: discuss on GitHub issue #89")
+    length(ys) == length(fx) || error(
+        "LaplaceApproximation currently does not support multi-latent likelihoods; please open an issue on GitHub",
+    )
     dist_y_given_f = lfx.lik
     K = cov(fx)
     if isnothing(f_init)
@@ -337,7 +340,9 @@ function ChainRulesCore.rrule(::typeof(newton_inner_loop), dist_y_given_f, ys, K
     )
 
     # ∂K = df/dK Δf
-    ∂K = @thunk(cache.Wsqrt * (cache.B_ch \ (cache.Wsqrt \ Δf_opt)) * cache.d_loglik')
+    ∂K = ChainRulesCore.@thunk(
+        cache.Wsqrt * (cache.B_ch \ (cache.Wsqrt \ Δf_opt)) * cache.d_loglik'
+    )
 
     return (∂self, ∂dist_y_given_f, ∂ys, ∂K)
 end
diff --git a/test/ExpectationPropagationModule.jl b/test/ExpectationPropagationModule.jl
new file mode 100644
index 00000000..67d0ab2a
--- /dev/null
+++ b/test/ExpectationPropagationModule.jl
@@ -0,0 +1,30 @@
+@testset "Expectation Propagation" begin
+    @testset "moment_match" begin
+        # reference implementation: brute-force moment matching via adaptive
+        # quadrature over ±20 standard deviations of the cavity
+        function moment_match_quadgk(cav_i::UnivariateDistribution, lik_eval_i)
+            lower = mean(cav_i) - 20 * std(cav_i)
+            upper = mean(cav_i) + 20 * std(cav_i)
+            m0, _ = quadgk(f -> pdf(cav_i, f) * lik_eval_i(f), lower, upper)
+            m1, _ = quadgk(f -> f * pdf(cav_i, f) * lik_eval_i(f), lower, upper)
+            m2, _ = quadgk(f -> f^2 * pdf(cav_i, f) * lik_eval_i(f), lower, upper)
+            matched_Z = m0
+            matched_mean = m1 / m0
+            matched_var = m2 / m0 - matched_mean^2
+            return (; Z=matched_Z, q=Normal(matched_mean, sqrt(matched_var)))
+        end
+
+        cav_i = Normal(0.8231, 3.213622) # arbitrary values
+        lik_eval_i = f -> pdf(Bernoulli(logistic(f)), true)
+        Z_gh, q_gh = ExpectationPropagationModule.moment_match(
+            cav_i, lik_eval_i; n_points=100
+        )
+        Z_quad, q_quad = moment_match_quadgk(cav_i, lik_eval_i)
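+        # Gauss-Hermite quadrature should agree with the (slower but reliable)
+        # adaptive-quadrature reference: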
+        @test Z_gh ≈ Z_quad
+        @test mean(q_gh) ≈ mean(q_quad)
+        @test std(q_gh) ≈ std(q_quad)
+    end
+
+    @testset "predictions" begin
+        approx = ApproximateGPs.ExpectationPropagation(; n_gh=500)
+        ApproximateGPs.TestUtils.test_approximation_predictions(approx)
+    end
+end
diff --git a/test/LaplaceApproximationModule.jl b/test/LaplaceApproximationModule.jl
index a3f705c2..a543d64b 100644
--- a/test/LaplaceApproximationModule.jl
+++ b/test/LaplaceApproximationModule.jl
@@ -28,6 +28,8 @@ end
 
     @testset "predictions" begin
+        # in the Gaussian case, Laplace converges to f_opt in one step; we need
+        # the second step to compute the cache at f_opt rather than at f_init!
         approx = LaplaceApproximation(; maxiter=2)
         ApproximateGPs.TestUtils.test_approximation_predictions(approx)
     end
diff --git a/test/Project.toml b/test/Project.toml
index 39e43130..2e2e3e7a 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -27,4 +28,5 @@ IterTools = "1"
 LogExpFunctions = "0.3"
 Optim = "1"
 PDMats = "0.11"
+QuadGK = "2"
 Zygote = "0.6"
diff --git a/test/runtests.jl b/test/runtests.jl
index a1ec05f4..10d2d664 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,7 +15,12 @@ using Zygote
 using AbstractGPs
 using ApproximateGPs
 
-using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
+using ApproximateGPs:
+    SparseVariationalApproximationModule,
+    LaplaceApproximationModule,
+    ExpectationPropagationModule
+
+using QuadGK
 
 # Writing tests:
 # 1. The file structure of the test should match precisely the file structure of src.
@@ -62,4 +67,8 @@ include("test_utils.jl")
     include("LaplaceApproximationModule.jl")
     println(" ")
     @info "Ran laplace tests"
+
+    include("ExpectationPropagationModule.jl")
+    println(" ")
+    @info "Ran ep tests"
 end