Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,16 @@ function expected_loglikelihood(
mc::MonteCarloExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
# take `n_samples` reparameterised samples
f_μ = mean.(q_f)
fs = f_μ .+ std.(q_f) .* randn(eltype(f_μ), length(q_f), mc.n_samples)
lls = loglikelihood.(lik.(fs), y)
r = randn(typeof(mean(first(q_f))), length(q_f), mc.n_samples)
lls = _mc_exp_loglikelihood_kernel.(_maybe_ref(lik), q_f, y, r)
return sum(lls) / mc.n_samples
end

function _mc_exp_loglikelihood_kernel(lik, q_f, y, r)
f = mean(q_f) + std(q_f) * r
return loglikelihood(lik(f), y)
end

# Compute the expected_loglikelihood over a collection of observations and marginal distributions
function expected_loglikelihood(
gh::GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
Expand All @@ -92,14 +96,21 @@ function expected_loglikelihood(
# type stable. Compared to other type stable implementations, e.g.
# using a custom two-argument pairwise sum, this is faster to
# differentiate using Zygote.
A = loglikelihood.(lik.(sqrt2 .* std.(q_f) .* gh.xs' .+ mean.(q_f)), y) .* gh.ws'
A = _gh_exp_loglikelihood_kernel.(_maybe_ref(lik), q_f, y, gh.xs', gh.ws')
return invsqrtπ * sum(A)
end

function _gh_exp_loglikelihood_kernel(lik, q_f, y, x, w)
return loglikelihood(lik(sqrt2 * std(q_f) * x + mean(q_f)), y) * w
end

function expected_loglikelihood(
::AnalyticExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
return error(
"No analytic solution exists for $(typeof(lik)). Use `DefaultExpectationMethod`, `GaussHermiteExpectation` or `MonteCarloExpectation` instead.",
)
end

_maybe_ref(lik) = Ref(lik)
_maybe_ref(liks::AbstractArray) = liks
14 changes: 12 additions & 2 deletions src/likelihoods/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,18 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
f_μ = mean.(q_f)
return sum(-f_μ - y .* exp.((var.(q_f) / 2) .- f_μ))
return sum(_exp_exp_loglikelihood_kernel.(q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
::AbstractVector{<:ExponentialLikelihood{ExpLink}},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_exp_exp_loglikelihood_kernel.(q_f, y))
end

_exp_exp_loglikelihood_kernel(q_f, y) = -mean(q_f) - y * exp((var(q_f) / 2) - mean(q_f))

default_expectation_method(::ExponentialLikelihood{ExpLink}) = AnalyticExpectation()
20 changes: 15 additions & 5 deletions src/likelihoods/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,21 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
f_μ = mean.(q_f)
return sum(
(lik.α - 1) * log.(y) .- y .* exp.((var.(q_f) / 2) .- f_μ) .- lik.α * f_μ .-
loggamma(lik.α),
)
return sum(_gamma_exp_loglikelihood_kernel.(lik.α, q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
liks::AbstractVector{<:GammaLikelihood{ExpLink}},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_gamma_exp_loglikelihood_kernel.(getfield.(liks, :α), q_f, y))
end

function _gamma_exp_loglikelihood_kernel(α, q_f, y)
return (α - 1) * log(y) - y * exp((var(q_f) / 2) - mean(q_f)) - α * mean(q_f) -
loggamma(α)
end

default_expectation_method(::GammaLikelihood{ExpLink}) = AnalyticExpectation()
17 changes: 14 additions & 3 deletions src/likelihoods/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,20 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(
-0.5 * (log(2π) .+ log.(lik.σ²) .+ ((y .- mean.(q_f)) .^ 2 .+ var.(q_f)) / lik.σ²)
)
return sum(_gaussian_exp_loglikelihood_kernel.(lik.σ², q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
liks::AbstractVector{<:GaussianLikelihood},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_gaussian_exp_loglikelihood_kernel.(only.(getfield.(liks, :σ²)), q_f, y))
end

function _gaussian_exp_loglikelihood_kernel(σ², q_f, y)
return -0.5 * (log(2π) + log(σ²) + ((y - mean(q_f))^2 + var(q_f)) / σ²)
end

default_expectation_method(::GaussianLikelihood) = AnalyticExpectation()
Expand Down
16 changes: 14 additions & 2 deletions src/likelihoods/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,20 @@ function expected_loglikelihood(
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
f_μ = mean.(q_f)
return sum((y .* f_μ) - exp.(f_μ .+ (var.(q_f) / 2)) - loggamma.(y .+ 1))
return sum(_poisson_exp_loglikelihood_kernel.(q_f, y))
end

function expected_loglikelihood(
::AnalyticExpectation,
::AbstractArray{<:PoissonLikelihood{ExpLink}},
q_f::AbstractVector{<:Normal},
y::AbstractVector{<:Real},
)
return sum(_poisson_exp_loglikelihood_kernel.(q_f, y))
end

function _poisson_exp_loglikelihood_kernel(q_f, y)
return (y * mean(q_f)) - exp(mean(q_f) + (var(q_f) / 2)) - loggamma(y + 1)
end

default_expectation_method(::PoissonLikelihood{ExpLink}) = AnalyticExpectation()
24 changes: 24 additions & 0 deletions test/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
m.lik for m in implementation_types if
m.quadrature == GPLikelihoods.AnalyticExpectation && m.lik != Any
]
filter!(x -> !(x <: AbstractArray), analytic_likelihoods)
for lik_type in analytic_likelihoods
lik_type_instances = filter(lik -> isa(lik, lik_type), likelihoods_to_test)
@test !isempty(lik_type_instances)
Expand Down Expand Up @@ -120,4 +121,27 @@
)
@test isfinite(glogα)
end

@testset "non-constant likelihood" begin
@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
liks = fill(lik, 10)
# Test that the various methods of computing expectations return the same
# result.
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = [rand(rng, lik(0.)) for lik in liks]

results = map(
m -> GPLikelihoods.expected_loglikelihood(m, liks, q_f, y), methods
)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
end
end
end