diff --git a/Project.toml b/Project.toml index 02a02f68..6ce50610 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" +KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 029be317..41b7c116 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -19,6 +19,10 @@ include("LaplaceApproximationModule.jl") @reexport using .LaplaceApproximationModule: build_laplace_objective, build_laplace_objective! +include("FiniteBasisModule.jl") +@reexport using .FiniteBasisModule: + FFApprox, FiniteBasis, DegeneratePosterior + include("deprecations.jl") include("TestUtils.jl") diff --git a/src/FiniteBasisModule.jl b/src/FiniteBasisModule.jl new file mode 100644 index 00000000..a1d777b5 --- /dev/null +++ b/src/FiniteBasisModule.jl @@ -0,0 +1,76 @@ +module FiniteBasisModule + +using KernelFunctions,LinearAlgebra, AbstractGPs, Random +import AbstractGPs: AbstractGP, FiniteGP +import Statistics +import ChainRulesCore + +struct FiniteBasis <: KernelFunctions.SimpleKernel end + +KernelFunctions.kappa(::FiniteBasis, d::Real) = d +KernelFunctions.metric(::FiniteBasis) = KernelFunctions.DotProduct() + +struct DegeneratePosterior{P,T,C} <: AbstractGP + prior::P + w_mean::T + w_prec::C +end + +weight_form(A::KernelFunctions.ColVecs) = A.X' +weight_form(A::KernelFunctions.RowVecs) = A.X + +function AbstractGPs.posterior(fx::FiniteGP{GP{M, B}}, y::AbstractVector{<:Real}) where {M, B <: FiniteBasis} + kern = fx.f.kernel + δ = y - mean(fx) + X = weight_form(fx.x) + X_prec = X' * inv(fx.Σy) + Λμ = X_prec * y + prec = cholesky(I + Symmetric(X_prec * X)) + w = prec \ Λμ + DegeneratePosterior(fx.f, w, prec) +end + +function Statistics.mean(f::DegeneratePosterior, x::AbstractVector) + w = f.w_mean + X = weight_form(x) + X * w +end + +function Statistics.cov(f::DegeneratePosterior, x::AbstractVector) + X = weight_form(x) + AbstractGPs.Xt_invA_X(f.w_prec, X') +end + +function Statistics.cov(f::DegeneratePosterior, x::AbstractVector, y::AbstractVector) + X = weight_form(x) + Y = weight_form(y) + AbstractGPs.Xt_invA_Y(X', f.w_prec, Y') +end + +function Statistics.var(f::DegeneratePosterior, x::AbstractVector) + X = weight_form(x) + AbstractGPs.diag_Xt_invA_X(f.w_prec, X') +end + +function Statistics.rand(rng::AbstractRNG, f::DegeneratePosterior, x::AbstractVector) + w = f.w_mean + X = weight_form(x) + X * (f.w_prec.U \ randn(rng, length(x))) +end + +struct RandomFourierFeature + ws::Vector{Float64} +end + +RandomFourierFeature(kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(k)) +RandomFourierFeature(rng::AbstractRNG, kern::SqExponentialKernel, k::Int) = RandomFourierFeature(randn(rng, k)) + +FFApprox(kern::Kernel, k::Int) = FiniteBasis() ∘ FunctionTransform(RandomFourierFeature(kern, k)) +FFApprox(rng::AbstractRNG, kern::Kernel, k::Int) = FiniteBasis() ∘ FunctionTransform(RandomFourierFeature(rng, kern, k)) + + +function (f::RandomFourierFeature)(x) + Float64[cos.(f.ws .* x); sin.(f.ws .* x)] .* sqrt(2/length(f.ws)) +end + +end \ No newline at end of file diff --git a/test/FiniteBasisModule.jl b/test/FiniteBasisModule.jl new file mode 100644 index 00000000..2dc3e867 --- /dev/null +++ b/test/FiniteBasisModule.jl @@ -0,0 +1,35 @@ +@testset "finite_basis" begin + rng = MersenneTwister(123456) + N = 50 + x = rand(rng, 2, N); + y = sin.(norm.(eachcol(x))) + + @testset "Verify equivalence of weight space and function space posteriors" begin + kern = FiniteBasis() + x2 = ColVecs(rand(2, N)) + + # Predict mean and covariance using weight space view + f = GP(kern) + fx = f(x, 0.001) + opt_pred = mean_and_cov(posterior(fx, y)(x2)) + + # Predict mean and covariance as normal + fx2 = GP(kern + ZeroKernel())(x, 0.001) + pred = mean_and_cov(posterior(fx2, y)(x2)) + + # The two approaches should be the same + @test all(opt_pred .≈ pred) + end + + @testset "Verify that the RFF approximation matches the exact posterior" begin + rng = MersenneTwister(12345) + rbf = SqExponentialKernel() + flat_x = rand(rng, N) + flat_x2 = rand(rng, N) + ffkern = FFApprox(rng, rbf, 200) + + opt_pred = mean_and_cov(posterior(GP(ffkern)(flat_x, 0.001), y)(flat_x2)) + pred = mean_and_cov(posterior(GP(rbf)(flat_x, 0.001), y)(flat_x2)) + @test all(isapprox.(opt_pred, pred; atol=1e-2)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index fa26a951..46009ddd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,10 @@ include("test_utils.jl") include("LaplaceApproximationModule.jl") println(" ") @info "Ran laplace tests" + + include("FiniteBasisModule.jl") + println(" ") + @info "Ran finite basis tests" end if GROUP == "All" || GROUP == "CUDA"