@@ -4,14 +4,12 @@ using ..API
44
55export SparseVariationalApproximation, Centered, NonCentered
66
7- using .. ApproximateGPs: _chol_cov, _cov
87using Distributions
98using LinearAlgebra
109using Statistics
1110using StatsBase
1211using FillArrays: Fill
1312using PDMats: chol_lower
14- using IrrationalConstants: sqrt2, invsqrtπ
1513
1614using AbstractGPs: AbstractGPs
1715using AbstractGPs:
@@ -24,10 +22,8 @@ using AbstractGPs:
2422 marginals,
2523 At_A,
2624 diag_At_A
27- using GPLikelihoods: GaussianLikelihood
28-
29- export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
30- include (" expected_loglik.jl" )
25+ using GPLikelihoods: GaussianLikelihood, DefaultExpectationMethod, expected_loglikelihood
26+ using .. ApproximateGPs: _chol_cov, _cov
3127
3228@doc raw """
3329 Centered()
@@ -289,22 +285,20 @@ end
289285 fx::FiniteGP,
290286 y::AbstractVector{<:Real};
291287 num_data=length(y),
292- quadrature=DefaultQuadrature (),
288+ quadrature=GPLikelihoods.DefaultExpectationMethod (),
293289 )
294290
295291Compute the Evidence Lower BOund from [1] for the process `f = fx.f ==
296292svgp.fz.f` where `y` are observations of `fx`, pseudo-inputs are given by `z =
297293svgp.fz.x` and `q(u)` is a variational distribution over inducing points `u =
298294f(z)`.
299295
300- `quadrature` selects which method is used to calculate the expected loglikelihood in
301- the ELBO. The options are: `DefaultQuadrature()`, `Analytic()`, `GaussHermite()` and
302- `MonteCarlo()`. For likelihoods with an analytic solution, `DefaultQuadrature()` uses this
303- exact solution. If there is no such solution, `DefaultQuadrature()` either uses
304- `GaussHermite()` or `MonteCarlo()`, depending on the likelihood.
296+ `quadrature` is passed on to `GPLikelihoods.expected_loglikelihood` and selects
297+ which method is used to calculate the expected loglikelihood in the ELBO. See
298+ `GPLikelihoods.expected_loglikelihood` for more details.
305299
306300N.B. the likelihood is assumed to be Gaussian with observation noise `fx.Σy`.
307- Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = α * I`.
301+ Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = σ² * I`.
308302
309303[1] - Hensman, James, Alexander Matthews, and Zoubin Ghahramani. "Scalable
310304variational Gaussian process classification." Artificial Intelligence and
@@ -315,10 +309,11 @@ function AbstractGPs.elbo(
315309 fx:: FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}} ,
316310 y:: AbstractVector{<:Real} ;
317311 num_data= length (y),
318- quadrature= DefaultQuadrature (),
312+ quadrature= DefaultExpectationMethod (),
319313)
320- @assert sva. fz. f === fx. f
321- return _elbo (quadrature, sva, fx, y, GaussianLikelihood (fx. Σy[1 ]), num_data)
314+ σ² = fx. Σy[1 ]
315+ lik = GaussianLikelihood (σ²)
316+ return elbo (sva, LatentFiniteGP (fx, lik), y; num_data, quadrature)
322317end
323318
324319function AbstractGPs. elbo (
337332 lfx::LatentFiniteGP,
338333 y::AbstractVector;
339334 num_data=length(y),
340- quadrature=DefaultQuadrature (),
335+ quadrature=GPLikelihoods.DefaultExpectationMethod (),
341336 )
342337
343338Compute the ELBO for a LatentGP with a possibly non-conjugate likelihood.
@@ -347,26 +342,17 @@ function AbstractGPs.elbo(
347342 lfx:: LatentFiniteGP ,
348343 y:: AbstractVector ;
349344 num_data= length (y),
350- quadrature= DefaultQuadrature (),
351- )
352- @assert sva. fz. f === lfx. fx. f
353- return _elbo (quadrature, sva, lfx. fx, y, lfx. lik, num_data)
354- end
355-
356- # Compute the common elements of the ELBO
357- function _elbo (
358- quadrature:: QuadratureMethod ,
359- sva:: SparseVariationalApproximation ,
360- fx:: FiniteGP ,
361- y:: AbstractVector ,
362- lik,
363- num_data:: Integer ,
345+ quadrature= DefaultExpectationMethod (),
364346)
365- @assert sva. fz. f === fx. f
347+ sva. fz. f === lfx. fx. f || throw (
348+ ArgumentError (
349+ " (Latent)FiniteGP prior is not consistent with SparseVariationalApproximation's" ,
350+ ),
351+ )
366352
367353 f_post = posterior (sva)
368- q_f = marginals (f_post (fx. x))
369- variational_exp = expected_loglik (quadrature, y , q_f, lik )
354+ q_f = marginals (f_post (lfx . fx. x))
355+ variational_exp = expected_loglikelihood (quadrature, lfx . lik , q_f, y )
370356
371357 n_batch = length (y)
372358 scale = num_data / n_batch
0 commit comments