Skip to content

Commit 8ceb309

Browse files
authored
move expected loglik to GPLikelihoods (#123)
* move to GPLikelihoods 0.4; remove expected_loglik from here * remove unnecessary internal function
1 parent 490ece8 commit 8ceb309

File tree

8 files changed

+24
-305
lines changed

8 files changed

+24
-305
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ApproximateGPs"
22
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.3.4"
4+
version = "0.4.0"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
@@ -29,7 +29,7 @@ Distributions = "0.25"
2929
FastGaussQuadrature = "0.4"
3030
FillArrays = "0.12, 0.13"
3131
ForwardDiff = "0.10"
32-
GPLikelihoods = "0.3"
32+
GPLikelihoods = "0.4"
3333
IrrationalConstants = "0.1"
3434
LogExpFunctions = "0.3"
3535
PDMats = "0.11"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44

55
[compat]
6-
ApproximateGPs = "0.3"
6+
ApproximateGPs = "0.3,0.4"
77
Documenter = "0.27"

src/ApproximateGPs.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ include("utils.jl")
1313
include("SparseVariationalApproximationModule.jl")
1414
@reexport using .SparseVariationalApproximationModule:
1515
SparseVariationalApproximation, Centered, NonCentered
16-
@reexport using .SparseVariationalApproximationModule:
17-
DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
1816

1917
include("LaplaceApproximationModule.jl")
2018
@reexport using .LaplaceApproximationModule: LaplaceApproximation

src/SparseVariationalApproximationModule.jl

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@ using ..API
44

55
export SparseVariationalApproximation, Centered, NonCentered
66

7-
using ..ApproximateGPs: _chol_cov, _cov
87
using Distributions
98
using LinearAlgebra
109
using Statistics
1110
using StatsBase
1211
using FillArrays: Fill
1312
using PDMats: chol_lower
14-
using IrrationalConstants: sqrt2, invsqrtπ
1513

1614
using AbstractGPs: AbstractGPs
1715
using AbstractGPs:
@@ -24,10 +22,8 @@ using AbstractGPs:
2422
marginals,
2523
At_A,
2624
diag_At_A
27-
using GPLikelihoods: GaussianLikelihood
28-
29-
export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
30-
include("expected_loglik.jl")
25+
using GPLikelihoods: GaussianLikelihood, DefaultExpectationMethod, expected_loglikelihood
26+
using ..ApproximateGPs: _chol_cov, _cov
3127

3228
@doc raw"""
3329
Centered()
@@ -289,22 +285,20 @@ end
289285
fx::FiniteGP,
290286
y::AbstractVector{<:Real};
291287
num_data=length(y),
292-
quadrature=DefaultQuadrature(),
288+
quadrature=GPLikelihoods.DefaultExpectationMethod(),
293289
)
294290
295291
Compute the Evidence Lower BOund from [1] for the process `f = fx.f ==
296292
svgp.fz.f` where `y` are observations of `fx`, pseudo-inputs are given by `z =
297293
svgp.fz.x` and `q(u)` is a variational distribution over inducing points `u =
298294
f(z)`.
299295
300-
`quadrature` selects which method is used to calculate the expected loglikelihood in
301-
the ELBO. The options are: `DefaultQuadrature()`, `Analytic()`, `GaussHermite()` and
302-
`MonteCarlo()`. For likelihoods with an analytic solution, `DefaultQuadrature()` uses this
303-
exact solution. If there is no such solution, `DefaultQuadrature()` either uses
304-
`GaussHermite()` or `MonteCarlo()`, depending on the likelihood.
296+
`quadrature` is passed on to `GPLikelihoods.expected_loglikelihood` and selects
297+
which method is used to calculate the expected loglikelihood in the ELBO. See
298+
`GPLikelihoods.expected_loglikelihood` for more details.
305299
306300
N.B. the likelihood is assumed to be Gaussian with observation noise `fx.Σy`.
307-
Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = α * I`.
301+
Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = σ² * I`.
308302
309303
[1] - Hensman, James, Alexander Matthews, and Zoubin Ghahramani. "Scalable
310304
variational Gaussian process classification." Artificial Intelligence and
@@ -315,10 +309,11 @@ function AbstractGPs.elbo(
315309
fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}},
316310
y::AbstractVector{<:Real};
317311
num_data=length(y),
318-
quadrature=DefaultQuadrature(),
312+
quadrature=DefaultExpectationMethod(),
319313
)
320-
@assert sva.fz.f === fx.f
321-
return _elbo(quadrature, sva, fx, y, GaussianLikelihood(fx.Σy[1]), num_data)
314+
σ² = fx.Σy[1]
315+
lik = GaussianLikelihood(σ²)
316+
return elbo(sva, LatentFiniteGP(fx, lik), y; num_data, quadrature)
322317
end
323318

324319
function AbstractGPs.elbo(
@@ -337,7 +332,7 @@ end
337332
lfx::LatentFiniteGP,
338333
y::AbstractVector;
339334
num_data=length(y),
340-
quadrature=DefaultQuadrature(),
335+
quadrature=GPLikelihoods.DefaultExpectationMethod(),
341336
)
342337
343338
Compute the ELBO for a LatentGP with a possibly non-conjugate likelihood.
@@ -347,26 +342,17 @@ function AbstractGPs.elbo(
347342
lfx::LatentFiniteGP,
348343
y::AbstractVector;
349344
num_data=length(y),
350-
quadrature=DefaultQuadrature(),
351-
)
352-
@assert sva.fz.f === lfx.fx.f
353-
return _elbo(quadrature, sva, lfx.fx, y, lfx.lik, num_data)
354-
end
355-
356-
# Compute the common elements of the ELBO
357-
function _elbo(
358-
quadrature::QuadratureMethod,
359-
sva::SparseVariationalApproximation,
360-
fx::FiniteGP,
361-
y::AbstractVector,
362-
lik,
363-
num_data::Integer,
345+
quadrature=DefaultExpectationMethod(),
364346
)
365-
@assert sva.fz.f === fx.f
347+
sva.fz.f === lfx.fx.f || throw(
348+
ArgumentError(
349+
"(Latent)FiniteGP prior is not consistent with SparseVariationalApproximation's",
350+
),
351+
)
366352

367353
f_post = posterior(sva)
368-
q_f = marginals(f_post(fx.x))
369-
variational_exp = expected_loglik(quadrature, y, q_f, lik)
354+
q_f = marginals(f_post(lfx.fx.x))
355+
variational_exp = expected_loglikelihood(quadrature, lfx.lik, q_f, y)
370356

371357
n_batch = length(y)
372358
scale = num_data / n_batch

src/expected_loglik.jl

Lines changed: 0 additions & 168 deletions
This file was deleted.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[compat]
1919
AbstractGPs = "0.4, 0.5"
20-
ApproximateGPs = "0.3"
20+
ApproximateGPs = "0.4"
2121
ChainRulesCore = "1"
2222
ChainRulesTestUtils = "1.2.3"
2323
Distributions = "0.25"

0 commit comments

Comments
 (0)