Skip to content

Commit e3dc9a5

Browse files
authored
Add matern kernel and lean more on Distributions.jl (#16)
* Start on new abstraction * Start rewrite test files for easier spot testing, add matern file * Add matern kernel, rewrite things in terms of distributions * Start implementing MvTDist * Forgot main file * Move dimension handing to base functions * More steps towards MvTDist * Use undocumented MvTDist * Improve tests * Fix tests for matern kernel * Remove SpecialFunctions again * Adjust squared exponential to new format * Some cleanup * Add compat entries * Fix base tests * Fix test script * Small fix * More test fixes * Remove obsolte fallback
1 parent 8316151 commit e3dc9a5

File tree

13 files changed

+269
-71
lines changed

13 files changed

+269
-71
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1111

1212
[compat]
13-
julia = "1.6"
13+
julia = "1.8"
14+
Distributions = "0.25.108"
15+
KernelFunctions = "0.10.63"
16+
Reexport = "1.2.2"
1417

1518
[extras]
1619
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

examples/approx-prior-sample.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

examples/densities.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

examples/feature-functions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/KernelSpectralDensities.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@ export SpectralDensity
1212
export ShiftedRFF, DoubleRFF
1313
export ApproximateGPSample
1414

15-
# write tests to verify spectral density via fourier transforms
16-
# also add SpectralKernel (which can then be learned?)
17-
1815
include("base.jl")
1916
include("expkernels.jl")
17+
include("matern.jl")
2018
include("features.jl")
2119
include("approx_prior.jl")
2220

src/base.jl

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11

22
abstract type AbstractSpectralDensity end
33

4-
(S::AbstractSpectralDensity)(w) = error("Not implemented")
5-
64
"""
75
rand(S::AbstractSpectralDensity, [n::Int])
86
@@ -37,22 +35,63 @@ julia> k = SqExponentialKernel();
3735
julia> S = SpectralDensity(k, 1);
3836
3937
julia> S(0.0)
40-
2.5066282746310002
38+
2.5066282746310007
4139
4240
julia> S = SpectralDensity(k, 2);
4341
4442
julia> S(zeros(2))
45-
6.283185307179585
43+
6.2831853071795845
4644
```
4745
"""
48-
struct SpectralDensity{K<:KernelFunctions.Kernel} <: AbstractSpectralDensity
46+
struct SpectralDensity{K<:KernelFunctions.Kernel,D<:Distribution} <: AbstractSpectralDensity
4947
kernel::K
50-
dim::Int
48+
# dim::Int
49+
d::D
5150

5251
function SpectralDensity(kernel::KernelFunctions.Kernel, dim::Int)
5352
if dim < 1
5453
throw(ArgumentError("Dimension must be greater than 0"))
5554
end
56-
return new{typeof(kernel)}(kernel, dim)
55+
56+
sk, l = _deconstruct_kernel(kernel, dim)
57+
d = _spectral_distribution(sk, l)
58+
59+
return new{typeof(kernel),typeof(d)}(kernel, d)
60+
end
61+
end
62+
63+
function (S::SpectralDensity)(w)
64+
return pdf(S.d, w)
65+
end
66+
67+
function rand(rng::AbstractRNG, S::SpectralDensity, n::Int...)
68+
return rand(rng, S.d, n...)
69+
end
70+
71+
# ToDo: This could perhaps go into a separate file
72+
function _deconstruct_kernel(ker::KernelFunctions.SimpleKernel, dim::Int)
73+
if dim == 1
74+
l = 1.0
75+
else
76+
l = ones(dim)
77+
end
78+
return ker, l
79+
end
80+
81+
function _deconstruct_kernel(
82+
ker::TransformedKernel{<:KernelFunctions.SimpleKernel,<:ScaleTransform}, dim::Int
83+
)
84+
l = inv(only(ker.transform.s))
85+
if dim > 1
86+
l = ones(dim) * l
5787
end
88+
return ker.kernel, l
89+
end
90+
91+
function _deconstruct_kernel(ker::TransformedKernel, dim::Int)
92+
return throw(MethodError(_deconstruct_kernel, (ker, dim)))
5893
end
94+
95+
function _spectral_distribution(ker::KernelFunctions.Kernel, l)
96+
return throw(MethodError(_spectral_distribution, (ker,)))
97+
end

src/expkernels.jl

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,11 @@
33
## Squared ExponentialKernel
44

55
# ToDo: Not sure about distances? Do all work?
6-
(S::SpectralDensity{<:SqExponentialKernel})(w) = _sqexp(w, 1)
7-
8-
function (S::SpectralDensity{<:TransformedKernel{<:SqExponentialKernel,<:ScaleTransform}})(
9-
w
10-
)
11-
l = 1 / only(S.kernel.transform.s)
12-
return _sqexp(w, l^2)
13-
end
14-
15-
_sqexp(w::Real, l2::Real) = sqrt(2 * l2 * π) * exp(-2 * l2 * π^2 * w^2)
16-
function _sqexp(w::AbstractVector{<:Real}, l2::Real)
17-
d = length(w)
18-
return sqrt(2 * l2 * π)^d * exp(-2 * l2 * π^2 * dot(w, w))
19-
end
20-
21-
function rand(rng::AbstractRNG, S::SpectralDensity{<:SqExponentialKernel}, n::Int...)
22-
return _sqexprand(rng, S.dim, 1, n...)
23-
end
24-
25-
function rand(
26-
rng::AbstractRNG,
27-
S::SpectralDensity{<:TransformedKernel{<:SqExponentialKernel,<:ScaleTransform}},
28-
n::Int...,
29-
)
30-
l = 1 / only(S.kernel.transform.s)
31-
return _sqexprand(rng, S.dim, l, n...)
6+
function _spectral_distribution(ker::SqExponentialKernel, l::Real)
7+
return inv(2 * π * l) * Normal()
328
end
339

34-
function _sqexprand(rng::AbstractRNG, d::Int, l::Real, n::Int...)
35-
σ = 1 / (2 * l * π)
36-
if d == 1
37-
return rand(rng, Normal(0, σ), n...)
38-
elseif d > 1
39-
σv = ones(d) * abs2(σ)
40-
return rand(rng, MvNormal(Diagonal(σv)), n...)
41-
end
10+
function _spectral_distribution(ker::SqExponentialKernel, l::AbstractVector)
11+
σv = abs2.(inv.(2 * π * l))
12+
return MvNormal(Diagonal(σv))
4213
end

src/matern.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
###################################################
3+
## Matern Kernels
4+
5+
MaternKernels = Union{MaternKernel,Matern32Kernel,Matern52Kernel}
6+
7+
_matern_order(k::MaternKernel) = only(k.ν)
8+
_matern_order(::Matern32Kernel) = 3 / 2
9+
_matern_order(::Matern52Kernel) = 5 / 2
10+
11+
# rewrite everything as returning a distribution (kind of as originally planned)
12+
# should be able to abstract/ generalize a lot of the special casing
13+
function _spectral_distribution(kernel::MaternKernels, l::Real)
14+
ν = _matern_order(kernel)
15+
return inv(2 * π * l) * TDist(2 * ν)
16+
end
17+
18+
function _spectral_distribution(kernel::MaternKernels, l::AbstractVector)
19+
ν = _matern_order(kernel)
20+
n = length(l)
21+
l = inv.(2 * π * l) .^ 2
22+
D = Distributions.MvTDist(2 * ν, zeros(n), diagm(l))
23+
return D
24+
end

test/base.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ using KernelSpectralDensities
22
using Test
33

44
@testset "Fallback" begin
5-
ker = ZeroKernel()
5+
ker = ConstantKernel()
6+
7+
@test_throws MethodError SpectralDensity(ker, 1)
68

7-
S = SpectralDensity(ker, 1)
9+
kert = ker SelectTransform(1.0)
810

9-
@test_throws ErrorException S(1.0)
11+
@test_throws MethodError SpectralDensity(kert, 1)
1012
end
1113

1214
@testset "Dimension check" begin

test/expkernels.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,48 @@
1+
if (!@isdefined RUN_TESTS) || !RUN_TESTS
2+
using CairoMakie
3+
show_plot = true
4+
include("test_utils.jl")
5+
else
6+
show_plot = false
7+
end
8+
19
@testset "SquaredExponential Kernel" begin
210
ker = SqExponentialKernel()
311
@testset "1D" begin
412
@testset "Pure" begin
5-
w_interval = [-2.0, 2.0]
13+
# ker = SqExponentialKernel()
14+
w_interval = 2.0
615
t_interval = [0.0, 4.0]
716

8-
f = test_spectral_density(ker, w_interval, t_interval)
17+
test_spectral_density(ker, w_interval, t_interval; show_plot)
918
end
1019

1120
@testset "Scaled" begin
12-
ker = with_lengthscale(SqExponentialKernel(), 0.7)
13-
w_interval = [-2.0, 2.0]
21+
# ker = SqExponentialKernel()
22+
kert = with_lengthscale(ker, 0.7)
23+
w_interval = 2.0
1424
t_interval = [0.0, 4.0]
1525

16-
f = test_spectral_density(ker, w_interval, t_interval)
26+
f = test_spectral_density(kert, w_interval, t_interval; show_plot)
1727
end
1828
end
1929

2030
@testset "2D" begin
2131
@testset "Pure" begin
32+
# ker = SqExponentialKernel()
2233
w_interval = [-2.0, 2.0]
2334
x_interval = [-2.0, 2.0]
2435

25-
f = test_2Dspectral_density(ker, w_interval, x_interval)
36+
f = test_2Dspectral_density(ker, w_interval, x_interval; show_plot)
2637
end
2738

2839
@testset "Scaled" begin
29-
ker = with_lengthscale(SqExponentialKernel(), 0.7)
40+
# ker = SqExponentialKernel()
41+
kert = with_lengthscale(ker, 0.7)
3042
w_interval = [-2.0, 2.0]
3143
x_interval = [-2.0, 2.0]
3244

33-
f = test_2Dspectral_density(ker, w_interval, x_interval)
45+
f = test_2Dspectral_density(kert, w_interval, x_interval; show_plot)
3446
end
3547
end
3648
end

0 commit comments

Comments
 (0)