Skip to content

Commit 63942bc

Browse files
committed
Merge branch 'master-dev'
2 parents 6fc3a6f + d095935 commit 63942bc

23 files changed

+171
-49
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.2.0"
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
910
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1011
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1112
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

README.md

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,35 @@ KernelFunctions.jl provide a flexible and complete framework for kernel function
88

99
The aim is to make the API as model-agnostic as possible while still being user-friendly.
1010

11+
## Examples
12+
13+
```julia
14+
X = reshape(collect(range(-3.0,3.0,length=100)),:,1)
15+
# Set simple scaling of the data
16+
k₁ = SqExponentialKernel(1.0)
17+
K₁ = kernelmatrix(k,X,obsdim=1)
18+
19+
# Set a function transformation on the data
20+
k₂ = MaternKernel(FunctionTransform(x->sin.(x)))
21+
K₂ = kernelmatrix(k,X,obsdim=1)
22+
23+
# Set a matrix premultiplication on the data
24+
k₃ = PolynomialKernel(LowRankTransform(randn(4,1)),0.0,2.0)
25+
K₃ = kernelmatrix(k,X,obsdim=1)
26+
27+
# Add and sum kernels
28+
k₄ = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*k₂
29+
K₄ = kernelmatrix(k,X,obsdim=1)
30+
31+
heatmap([K₁,K₂,K₃,K₄],yflip=false,colorbar=false)
32+
```
33+
<p align=center>
34+
<img src="docs/src/assets/heatmap_combination.png" width=400px>
35+
</p>
36+
1137
## Objectives (by priority)
12-
- ARD Kernels
13-
- AD Compatible (Zygote, ForwardDiff, ReverseDiff)
14-
- Kernel sum and product
38+
- AD Compatibility (Zygote, ForwardDiff)
1539
- Toeplitz Matrices
1640
- BLAS backend
1741

18-
19-
Directly inspired by the [MLKernels](https://github.com/trthatcher/MLKernels.jl) package
42+
Directly inspired by the [MLKernels](https://github.com/trthatcher/MLKernels.jl) package.

docs/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
build/
22
site/
3+
4+
#Temp to avoid to many changes

docs/create_kernel_plots.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,33 @@ x₀ = 0.0; l=0.1
1010
n_grid = 101
1111
fill(x₀,n_grid,1)
1212
xrange = reshape(collect(range(-3,3,length=n_grid)),:,1)
13+
14+
k = SqExponentialKernel(1.0)
15+
K1 = kernelmatrix(k,xrange,obsdim=1)
16+
p = heatmap(K1,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
17+
savefig(joinpath(@__DIR__,"src","assets","heatmap_sqexp.png"))
18+
19+
20+
k = Matern32Kernel(FunctionTransform(x->(sin.(x)).^2))
21+
K2 = kernelmatrix(k,xrange,obsdim=1)
22+
p = heatmap(K2,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
23+
savefig(joinpath(@__DIR__,"src","assets","heatmap_matern.png"))
24+
25+
26+
k = PolynomialKernel(LowRankTransform(randn(3,1)),2.0,0.0)
27+
K3 = kernelmatrix(k,xrange,obsdim=1)
28+
p = heatmap(K3,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
29+
savefig(joinpath(@__DIR__,"src","assets","heatmap_poly.png"))
30+
31+
k = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*Matern32Kernel(FunctionTransform(x->sin.(x)))
32+
K4 = kernelmatrix(k,xrange,obsdim=1)
33+
p = heatmap(K4,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
34+
savefig(joinpath(@__DIR__,"src","assets","heatmap_prodsum.png"))
35+
36+
plot(heatmap.([K1,K2,K3,K4],yflip=true,colorbar=false)...,layout=(2,2))
37+
savefig(joinpath(@__DIR__,"src","assets","heatmap_combination.png"))
38+
39+
1340
for k in [SqExponentialKernel,ExponentialKernel]
1441
K = kernelmatrix(k(),xrange,obsdim=1)
1542
v = rand(MvNormal(K+1e-7I))
67.8 KB
Loading

docs/src/assets/heatmap_matern.png

37.8 KB
Loading

docs/src/assets/heatmap_poly.png

17.8 KB
Loading

docs/src/assets/heatmap_prodsum.png

21.5 KB
Loading

docs/src/assets/heatmap_sqexp.png

7.43 KB
Loading

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using Distances, LinearAlgebra
1616
using Zygote: @adjoint
1717
using SpecialFunctions: lgamma, besselk
1818
using StatsFuns: logtwo
19+
using PDMats
1920

2021
const defaultobs = 2
2122

@@ -32,7 +33,7 @@ kernels = ["exponential","matern","polynomial","constant","rationalquad","expone
3233
for k in kernels
3334
include(joinpath("kernels",k*".jl"))
3435
end
35-
include("kernelmatrix.jl")
36+
include("matrix/kernelmatrix.jl")
3637
include("kernels/kernelsum.jl")
3738
include("kernels/kernelproduct.jl")
3839

0 commit comments

Comments
 (0)