Skip to content

Commit 1002638

Browse files
committed
Merge branch 'master-dev'
2 parents e237992 + f3734d5 commit 1002638

16 files changed

+464
-152
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ version = "0.1.0"
55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
9+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
810
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
911

1012
[compat]
@@ -14,6 +16,7 @@ FiniteDifferences = ">= 0.7.2"
1416
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1517
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1618
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
19+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1720

1821
[targets]
1922
test = ["FiniteDifferences", "Random", "Test"]

dev/debugAD.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using KernelFunctions
2+
using Zygote, ForwardDiff, Tracker
3+
using Test, LinearAlgebra
4+
5+
dims = [10,5]
6+
A = rand(dims...)
7+
B = rand(dims...)
8+
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
9+
l = 0.1
10+
vl = l*ones(dims[1])
11+
testfunction(k,A,B) = det(kernelmatrix(k,A,B))
12+
testfunction(k,A) = sum(kernelmatrix(k,A))
13+
k = MaternKernel(vl)
14+
KernelFunctions.kappa(k,3)
15+
testfunction(SquaredExponentialKernel(vl),A)
16+
testfunction(MaternKernel(vl),A)
17+
@which kernelmatrix(MaternKernel(vl),A,B)
18+
#For debugging
19+
@info "Running Zygote gradients"
20+
Zygote.refresh()
21+
## Zygote
22+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
23+
Zygote.gradient(x->testfunction(MaternKernel(x),A),vl)
24+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)[1]
25+
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),vl)[1]
26+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
27+
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),l)
28+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
29+
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
30+
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
31+
Zygote.gradient(x->kernelmatrix(MaternKernel(x,1.0),A)[1],l)
32+
@info "Running Tracker gradients"
33+
## Tracker
34+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
35+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(l),x[:,:]),A)
36+
# # Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
37+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
38+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
39+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
40+
41+
@info "Running ForwardDiff gradients"
42+
## ForwardDiff
43+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl) #
44+
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A,B),vl) #
45+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl) #
46+
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A),vl) #
47+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A,B),[l])
48+
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A,B),[l])
49+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
50+
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A),[l])

dev/matrixvsvectors.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using KernelFunctions
2+
using Stheno
3+
using Stheno: pw
4+
using BenchmarkTools
5+
using Zygote
6+
7+
# Ds = [1,2,5,10,20,50,100,200,500,1000]
8+
Ds = [1,10,100,1000]
9+
timestheno = zeros(Float64,length(Ds)); memstheno = similar(timestheno)
10+
timekf = similar(timestheno); memkf = similar(timestheno)
11+
@progress for (i,D) in enumerate(Ds)
12+
13+
A = randn(D,1000)
14+
B = randn(D,1001)
15+
16+
# Standardised eq kernel with length-scale 0.1.
17+
medkf = median(@benchmark KernelFunctions.kernelmatrix(SquaredExponentialKernel(0.01),$A,$B,obsdim=2))
18+
timekf[i] = medkf.time/1e6; memkf[i] = medkf.memory/2^20
19+
medstheno = median(@benchmark pw(eq(; l=0.1), ColsAreObs($A), ColsAreObs($B)))
20+
timestheno[i] = medstheno.time/1e6; memstheno[i] = medstheno.memory/2^20
21+
end
22+
23+
using Plots
24+
ptime = plot(Ds,timestheno,lab="Stheno",xaxis=:log,xlabel="D",ylabel="t [ms]",title="Time")
25+
plot!(Ds,timekf,lab="KernelFunctions")
26+
pmem = plot(Ds,memstheno,lab="Stheno",xaxis=:log,xlabel="D",ylabel="Mem [MB]",title="Memory Usage")
27+
plot!(Ds,memkf,lab="KernelFunctions")
28+
plot(ptime,pmem)

src/KernelFunctions.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kappa
4-
export Kernel, SquaredExponentialKernel
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
4+
export Kernel, SquaredExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
5+
6+
export Transform, ScaleTransform
57

68
using Distances, LinearAlgebra
9+
using Zygote: @adjoint
10+
using SpecialFunctions: lgamma, besselk
11+
using StatsFuns: logtwo
712

813
const defaultobs = 2
9-
abstract type Kernel{T<:Real} end
14+
abstract type Kernel{T,Tr} end
1015

1116
include("zygote_rules.jl")
1217
include("utils.jl")
13-
include("common.jl")
18+
include("transform/transform.jl")
1419
include("kernelmatrix.jl")
1520

16-
kernels = ["squaredexponential"]
21+
kernels = ["squaredexponential","matern"]
1722
for k in kernels
1823
include(joinpath("kernels",k*".jl"))
1924
end
2025

26+
include("generic.jl")
27+
28+
2129
end

src/common.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/generic.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
@inline metric::Kernel) = κ.metric
3+
kernels =
4+
for k in [:SquaredExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
5+
eval(quote
6+
@inline::$k)(d::Real) = kappa(κ,d)
7+
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
8+
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9+
end)
10+
end
11+
### Transform generics
12+
13+
@inline transform::Kernel) = κ.transform
14+
@inline transform::Kernel,x::AbstractVecOrMat) = transform.transform,x)
15+
@inline transform::Kernel,x::AbstractVecOrMat,obsdim::Int) = transform.transform,x,obsdim)

src/kernelmatrix.jl

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
1-
2-
function _kappamatrix!::Kernel{T}, P::AbstractMatrix{T₁}) where {T<:Real,T₁<:Real}
3-
for i in eachindex(P)
4-
@inbounds P[i] = kappa(κ, P[i])
5-
end
6-
P
7-
end
8-
9-
function _symmetric_kappamatrix!(
1+
"""
2+
```
3+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer=2, symmetrize::Bool=true)
4+
```
5+
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
6+
"""
7+
function kernelmatrix!(
8+
K::Matrix{T₁},
109
κ::Kernel{T},
11-
P::AbstractMatrix{T₁},
12-
symmetrize::Bool
13-
) where {T<:Real,T₁<:Real}
14-
if !((n = size(P,1)) == size(P,2))
15-
throw(DimensionMismatch("Pairwise matrix must be square."))
16-
end
17-
for j = 1:n, i = (1:j)
18-
@inbounds P[i,j] = kappa(κ, P[i,j])
19-
end
20-
symmetrize ? LinearAlgebra.copytri!(P, 'U') : P
10+
X::AbstractMatrix{T₂};
11+
obsdim::Int = defaultobs,
12+
symmetrize::Bool = true
13+
) where {T,T₁<:Real,T₂<:Real}
14+
@assert check_dims(K,X,X,obsdim) "Dimensions of the target array are not consistent with X and Y"
15+
map!(x->kappa(κ,x),K,pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
2116
end
2217

23-
2418
"""
2519
```
2620
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
@@ -34,24 +28,10 @@ function kernelmatrix!(
3428
Y::AbstractMatrix{T₃};
3529
obsdim::Int = defaultobs
3630
) where {T,T₁,T₂,T₃}
37-
#TODO Check dimension consistency
38-
_kappamatrix!(κ, pairwise!(K, metric(κ), X, Y, dims=obsdim))
39-
end
40-
41-
42-
function kernelmatrix!(
43-
K::Matrix{T₁},
44-
κ::Kernel{T},
45-
X::AbstractMatrix{T₂};
46-
obsdim::Int = defaultobs,
47-
symmetrize::Bool = true
48-
) where {T,T₁<:Real,T₂<:Real}
49-
#TODO Check dimension consistency
50-
_symmetric_kappamatrix!(κ,pairwise!(K, metric(κ), X, dims=obsdim), symmetrize)
31+
@assert check_dims(K,X,Y,obsdim) "Dimensions of the target array are not consistent with X and Y"
32+
map!(x->kappa(κ,x),K,pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
5133
end
5234

53-
# Convenience Methods ======================================================================
54-
5535
"""
5636
```
5737
kernel(κ::Kernel, x, y; obsdim=2)
@@ -60,7 +40,7 @@ Apply the kernel `κ` to ``x`` and ``y`` where ``x`` and ``y`` are vectors or sc
6040
some subtype of ``Real``.
6141
"""
6242
function kernel::Kernel{T}, x::Real, y::Real) where {T}
63-
kernel(κ, T(x), T(y))
43+
kernel(κ, [T(x)], [T(y)])
6444
end
6545

6646
function kernel(
@@ -69,25 +49,23 @@ function kernel(
6949
y::AbstractArray{T₂};
7050
obsdim::Int = defaultobs
7151
) where {T,T₁<:Real,T₂<:Real}
72-
# TODO Verify dimensions
73-
kappa(κ, evaluate(metric(κ),x,y))
52+
@assert length(x) == length(y) "x and y don't have the same dimension!"
53+
kappa(κ, evaluate(metric(κ),transform(κ,x),transform(κ,y)))
7454
end
7555

7656
"""
7757
```
78-
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool)
58+
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool=true)
7959
```
8060
Calculate the kernel matrix of `X` with respect to kernel `κ`.
8161
"""
8262
function kernelmatrix(
83-
κ::Kernel{T},
84-
X::AbstractMatrix{T₁};
63+
κ::Kernel{T,<:Transform},
64+
X::AbstractMatrix;
8565
obsdim::Int = defaultobs,
8666
symmetrize::Bool = true
87-
) where {T,T₁<:Real}
88-
Tₛ = typeof(zero(eltype(X))*zero(T))
89-
m = size(X,obsdim)
90-
return kernelmatrix!(Matrix{promote_float(T,T₁)}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
67+
) where {T}
68+
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
9169
end
9270

9371
"""
@@ -102,13 +80,10 @@ function kernelmatrix(
10280
Y::AbstractMatrix{T₂};
10381
obsdim=defaultobs
10482
) where {T,T₁<:Real,T₂<:Real}
105-
Tₛ = typeof(zero(eltype(X))*zero(eltype(Y))*zero(T))
106-
m = size(X,obsdim)
107-
n = size(Y,obsdim)
108-
kernelmatrix!(Matrix{Tₛ}(undef,m,n),κ,X,Y,obsdim=obsdim)
83+
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
84+
return K
10985
end
11086

111-
11287
"""
11388
```
11489
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
@@ -117,8 +92,30 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
11792
"""
11893
function kerneldiagmatrix(
11994
κ::Kernel{T},
120-
X::AbstractMatrix{T₁}
95+
X::AbstractMatrix{T₁};
96+
obsdim::Int = defaultobs
97+
) where {T,T₁}
98+
if obsdim == 1
99+
[@views kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
100+
elseif obsdim == 2
101+
[@views kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
102+
end
103+
end
104+
105+
function kerneldiagmatrix!(
106+
K::AbstractVector{T₁},
107+
κ::Kernel{T},
108+
X::AbstractMatrix{T₂};
109+
obsdim::Int = defaultobs
121110
) where {T,T₁,T₂}
122-
@error "Not implemented yet"
123-
#TODO
111+
if obsdim == 1
112+
for i in eachindex(K)
113+
@inbounds @views K[i] = kernel(κ, X[i,:],X[i,:])
114+
end
115+
else
116+
for i in eachindex(K)
117+
@inbounds @views K[i] = kernel(κ,X[:,i],X[:,i])
118+
end
119+
end
120+
return K
124121
end

0 commit comments

Comments
 (0)