diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 2a134b4b..bde2daf3 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -306,26 +306,16 @@ Statistics. PMLR, 2015. """ function AbstractGPs.elbo( sva::SparseVariationalApproximation, - fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Union{Diagonal{<:Real,<:Fill},ScalMat}}, + fx::FiniteGP, y::AbstractVector{<:Real}; num_data=length(y), quadrature=DefaultExpectationMethod(), ) - σ² = fx.Σy[1] + σ² = _get_homoscedastic_noise(fx.Σy) lik = GaussianLikelihood(σ²) return elbo(sva, LatentFiniteGP(fx, lik), y; num_data, quadrature) end -function AbstractGPs.elbo( - ::SparseVariationalApproximation, ::FiniteGP, ::AbstractVector; kwargs... -) - return error( - "The observation noise fx.Σy must be homoscedastic.\n", - "To avoid this error, construct fx using: f = GP(kernel); fx = f(x, σ²)", - ", where σ² is a positive Real.", - ) -end - """ elbo( sva::SparseVariationalApproximation, @@ -372,4 +362,13 @@ function _prior_kl(sva::SparseVariationalApproximation{NonCentered}) return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2 end +_get_homoscedastic_noise(Σy::Union{Diagonal{<:Real,<:Fill},ScalMat}) = Σy[1] +function _get_homoscedastic_noise(_) + return error( + "The observation noise fx.Σy must be homoscedastic.\n", + "To avoid this error, construct fx using: f = GP(kernel); fx = f(x, σ²)", + ", where σ² is a positive Real.\n", + ) +end + end