Skip to content

Commit 3b3eb63

Browse files
committed
Debugged the Matern Kernel 3/2
1 parent 5fda2ba commit 3b3eb63

File tree

6 files changed

+19
-14
lines changed

6 files changed

+19
-14
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ version = "0.1.0"
55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
9+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
810
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
911

1012
[compat]

dev/debugAD.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
using KernelFunctions
22
using Zygote, ForwardDiff, Tracker
3-
using Test
3+
using Test, LinearAlgebra
44

55
dims = [10,5]
66
A = rand(dims...)
77
B = rand(dims...)
88
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
99
l = 0.1
1010
vl = l*ones(dims[1])
11-
testfunction(k,A,B) = sum(kernelmatrix(k,A,B))
11+
testfunction(k,A,B) = det(kernelmatrix(k,A,B))
1212
testfunction(k,A) = sum(kernelmatrix(k,A))
1313
k = MaternKernel(vl)
1414
KernelFunctions.kappa(k,3)
@@ -27,6 +27,8 @@ Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
2727
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),l)
2828
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
2929
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
30+
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
31+
Zygote.gradient(x->kernelmatrix(MaternKernel(x,1.0),A)[1],l)
3032
@info "Running Tracker gradients"
3133
## Tracker
3234
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export Transform, ScaleTransform
88
using Distances, LinearAlgebra
99
using Zygote: @adjoint
1010
using SpecialFunctions: lgamma, besselk
11+
using StatsFuns: logtwo
1112

1213
const defaultobs = 2
1314
abstract type Kernel{T,Tr} end

src/kernelmatrix.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ function kernelmatrix(
9494
# Tₖ = typeof(zero(eltype(X))*zero(T))
9595
# m = size(X,obsdim)
9696
#WARNING TEMP FIX
97-
= transform(κ,X,obsdim)
98-
K = map(x->kappa(κ,x),pairwise(metric(κ),X̂,X̂,dims=obsdim))
99-
# K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
97+
# X̂ = transform(κ,X,obsdim)
98+
# K = map(x->kappa(κ,x),pairwise(metric(κ),X̂,X̂,dims=obsdim))
99+
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
100100
return K
101101
end
102102

src/kernels/matern.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct MaternKernel{T,Tr<:Transform} <: Kernel{T,Tr}
2828
metric::SemiMetric
2929
ν::Real
3030
function MaternKernel{T,Tr}(transform::Tr::Real) where {T,Tr<:Transform}
31-
return new{T,Tr}(transform,SqEuclidean(),ν)
31+
return new{T,Tr}(transform,Euclidean(),ν)
3232
end
3333
end
3434

@@ -71,25 +71,25 @@ function MaternKernel(t::T₁,ν::T₂=1.5) where {T₁<:Transform,T₂<:Real}
7171
end
7272
end
7373

74-
@inline kappa::MaternKernel, d::Real) where {T} = exp((1.0-κ.ν)*log2 - lgamma.ν) - κ.ν*log(sqrt(2κ.ν*d²)))*besselk.ν,sqrt(2κ.ν*d²))
74+
@inline kappa::MaternKernel, d::Real) where {T} = exp((1.0-κ.ν)*logtwo - lgamma.ν) - κ.ν*log(sqrt(2κ.ν)*d))*besselk.ν,sqrt(2κ.ν)*d)
7575

7676

7777
struct Matern3_2Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
7878
transform::Tr
7979
metric::SemiMetric
8080
function Matern3_2Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
81-
return new{T,Tr}(transform,SqEuclidean())
81+
return new{T,Tr}(transform,Euclidean())
8282
end
8383
end
8484

85-
@inline kappa::Matern3_2Kernel, d²::T) where {T<:Real} = (1+sqrt(3*d²))*exp(-sqrt(3*d²))
85+
@inline kappa::Matern3_2Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
8686

8787
struct Matern5_2Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
8888
transform::Tr
8989
metric::SemiMetric
9090
function Matern5_2Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
91-
return new{T,Tr}(transform,SqEuclidean())
91+
return new{T,Tr}(transform,Euclidean())
9292
end
9393
end
9494

95-
@inline kappa::Matern5_2Kernel, d²::Real) where {T} = (1+sqrt(5*d²)+5*d²/3)*exp(-sqrt(5*d²))
95+
@inline kappa::Matern5_2Kernel, d::Real) where {T} = (1+sqrt(5)*d+5*d^2/3)*exp(-sqrt(5)*d)

test/testAD.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using KernelFunctions
22
using Zygote, ForwardDiff, Tracker
3-
using Test
3+
using Test, LinearAlgebra
44

55
dims = [10,5]
66

@@ -10,8 +10,8 @@ K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
1010
kernels = [SquaredExponentialKernel,MaternKernel]
1111
l = 2.0
1212
vl = l*ones(dims[1])
13-
testfunction(k,A,B) = sum(kernelmatrix(k,A,B))
14-
testfunction(k,A) = sum(kernelmatrix(k,A))
13+
testfunction(k,A,B) = det(kernelmatrix(k,A,B))
14+
testfunction(k,A) = det(kernelmatrix(k,A))
1515

1616
##Eventually store real results in file
1717
@testset "Zygote Automatic Differentiation test" begin

0 commit comments

Comments
 (0)