Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ include("deprecations.jl")

include("TestUtils.jl")

import ChainRulesCore: ProjectTo, Tangent
using PDMats: ScalMat
ProjectTo(x::T) where T <: ScalMat = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value))
(pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value))
(pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value))

end
8 changes: 4 additions & 4 deletions src/LaplaceApproximationModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,16 @@ function _check_laplace_inputs(
end

struct LaplaceCache{
Tm<:AbstractMatrix,Tv<:AbstractVector,Td<:Diagonal,Tf<:Real,Tc<:Cholesky
Tm<:AbstractMatrix,Tv1<:AbstractVector,Tv2<:AbstractVector,Tv3<:AbstractVector,Td<:Diagonal,Tf<:Real,Tc<:Cholesky
}
K::Tm # kernel matrix
f::Tv # mode of posterior p(f | y)
f::Tv1 # mode of posterior p(f | y)
W::Td # diagonal matrix of ∂²/∂fᵢ² loglik
Wsqrt::Td # sqrt(W)
loglik::Tf # ∑ᵢlog p(yᵢ|fᵢ)
d_loglik::Tv # ∂/∂fᵢloglik
d_loglik::Tv2 # ∂/∂fᵢloglik
B_ch::Tc # cholesky(I + Wsqrt * K * Wsqrt)
a::Tv # K⁻¹ f
a::Tv3 # K⁻¹ f
end

function _laplace_train_intermediates(dist_y_given_f, ys, K, f)
Expand Down
4 changes: 2 additions & 2 deletions src/SparseVariationalApproximationModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using LinearAlgebra
using Statistics
using StatsBase
using FillArrays: Fill
using PDMats: chol_lower
using PDMats: chol_lower, ScalMat

using AbstractGPs: AbstractGPs
using AbstractGPs:
Expand Down Expand Up @@ -306,7 +306,7 @@ Statistics. PMLR, 2015.
"""
function AbstractGPs.elbo(
sva::SparseVariationalApproximation,
fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}},
fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Union{Diagonal{<:Real,<:Fill},ScalMat}},
y::AbstractVector{<:Real};
num_data=length(y),
quadrature=DefaultExpectationMethod(),
Expand Down