
Commit 5e2b139

Merge remote-tracking branch 'origin/master'
2 parents: b0acc6f + 4fac93f

16 files changed: +138 −45 lines

examples/deepkernellearning.jl

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
using KernelFunctions
using MLDataUtils
using Zygote
using Flux
using Distributions, LinearAlgebra
using Plots

# Make the kernel and transform types trainable by Flux
Flux.@functor SqExponentialKernel
Flux.@functor KernelSum
Flux.@functor Matern32Kernel
Flux.@functor FunctionTransform

neuralnet = Chain(Dense(1,3), Dense(3,2)) # Feature map for the deep kernel
k = SqExponentialKernel(FunctionTransform(neuralnet))
xmin = -3; xmax = 3
x = range(xmin, xmax, length=100)
x_test = rand(Uniform(xmin, xmax), 200)
x, y = noisy_function(sinc, x; noise=0.1) # Noisy training data
X = reshape(x, :, 1)
λ = [0.1] # Log-regularisation parameter
# Kernel ridge regression predictor (argument order fixed to K(x,X) so f works for any x)
f(x, k, λ) = kernelmatrix(k, x, X, obsdim=1) * inv(kernelmatrix(k, X, obsdim=1) + exp(λ[1]) * I) * y
f(X, k, 1.0)
loss(k, λ) = f(X, k, λ) |> ŷ -> sum(abs2, y - ŷ) / length(y) + exp(λ[1]) * norm(ŷ) # Mean squared error plus regularisation
loss(k, λ)
ps = Flux.params(k)
# push!(ps,λ)
opt = Flux.Momentum(1.0)
##
for i in 1:10
    grads = Zygote.gradient(() -> loss(k, λ), ps)
    Flux.Optimise.update!(opt, ps, grads)
    p = Plots.scatter(x, y, lab="data", title="Loss = $(loss(k,λ))")
    Plots.plot!(x, f(X, k, λ), lab="Prediction", lw=3.0)
    display(p)
end
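For reference, this example learns a deep kernel: the FunctionTransform routes every input through the neuralnet Chain before the base kernel is evaluated, so (up to KernelFunctions' internal length-scale convention for SqExponentialKernel)

k(x, x') = \exp\big(-\lVert \phi_\theta(x) - \phi_\theta(x') \rVert^2\big),

where \phi_\theta is the network and \theta its Flux-trained weights; optimising the loss therefore adapts the feature map itself.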

examples/kernelridgeregression.jl

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
using KernelFunctions
using MLDataUtils
using Zygote
using Flux
using Distributions, LinearAlgebra
using Plots

# Make the kernel and transform types trainable by Flux
Flux.@functor SqExponentialKernel
Flux.@functor ScaleTransform
Flux.@functor KernelSum
Flux.@functor Matern32Kernel

xmin = -3; xmax = 3
x = range(xmin, xmax, length=100)
x_test = range(xmin, xmax, length=300)
x, y = noisy_function(sinc, x; noise=0.1) # Noisy training data
X = reshape(x, :, 1)
X_test = reshape(x_test, :, 1)
k = SqExponentialKernel(1.0) #+ Matern32Kernel(2.0)
λ = [-1.0] # Log-regularisation parameter
f(x, k, λ) = kernelmatrix(k, x, X, obsdim=1) * inv(kernelmatrix(k, X, obsdim=1) + exp(λ[1]) * I) * y
f(X, k, 1.0)
loss(k, λ) = f(X, k, λ) |> ŷ -> sum(abs2, y - ŷ) / length(y) + exp(λ[1]) * norm(ŷ) # Mean squared error plus regularisation
loss(k, λ)
ps = Flux.params(k)
push!(ps, λ) # Also optimise the regularisation parameter
opt = Flux.Momentum(0.1)
##
for i in 1:10
    grads = Zygote.gradient(() -> loss(k, λ), ps)
    Flux.Optimise.update!(opt, ps, grads)
    p = Plots.scatter(x, y, lab="data", title="Loss = $(loss(k,λ))")
    Plots.plot!(x_test, f(X_test, k, λ), lab="Prediction", lw=3.0)
    display(p)
end
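The f defined above is the standard kernel ridge regression predictor,

\hat{f}(x_*) = K(x_*, X)\,\big(K(X, X) + e^{\lambda} I\big)^{-1} y,

where writing the regularisation as e^{\lambda} keeps it positive while λ is optimised unconstrained alongside the kernel parameters.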

examples/svm.jl

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
using KernelFunctions
using MLDataUtils
using Zygote
using Flux
using Distributions, LinearAlgebra
using Plots

N = 100 # Number of samples
μ = randn(2, 2) # Random centers
xgrid = range(-3, 3, length=100) # Create a grid
Xgrid = hcat(collect.(Iterators.product(xgrid, xgrid))...)' # Combine into a 2D grid
y = rand([-1, 1], N) # Select randomly between the two classes
X = zeros(N, 2)
X[y .== 1, :] = rand(MvNormal(μ[:, 1], I), count(y .== 1))' # Attribute samples from class 1
X[y .== -1, :] = rand(MvNormal(μ[:, 2], I), count(y .== -1))' # Attribute samples from class 2

k = SqExponentialKernel(2.0) # Create kernel function
λ = [1.0] # Regularisation parameter (assumed value: λ was used but never defined in the original script)
f(x, k, λ) = kernelmatrix(k, x, X, obsdim=1) * inv(kernelmatrix(k, X, obsdim=1) + exp(λ[1]) * I) * y # Optimal prediction f
svmloss(k, λ) = f(X, k, λ) |> ŷ -> sum(max.(0.0, 1 .- y .* ŷ)) + exp(λ[1]) * norm(ŷ) # Total svm (hinge) loss with regularisation
pred = f(Xgrid, k, λ) # Compute prediction on a grid
contourf(xgrid, xgrid, pred)
scatter!(eachcol(X)..., color=y, lab="data")
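This script only visualises the initial predictor; nothing is trained yet. A training loop in the style of the two previous examples, as a sketch (it assumes the kernel is first made trainable with Flux.@functor, which this script does not do):

Flux.@functor SqExponentialKernel # let Flux.params collect the kernel parameters
Flux.@functor ScaleTransform
ps = Flux.params(k)
opt = Flux.Momentum(0.1)
for i in 1:10
    grads = Zygote.gradient(() -> svmloss(k, λ), ps)
    Flux.Optimise.update!(opt, ps, grads)
end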

src/KernelFunctions.jl

Lines changed: 3 additions & 5 deletions
@@ -1,7 +1,7 @@
 module KernelFunctions
 
-export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat
-export get_params, set_params!
+export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat # Main matrix functions
+export params, duplicate, set! # Helpers
 
 export Kernel
 export ConstantKernel, WhiteKernel, ZeroKernel

@@ -12,8 +12,7 @@ export LinearKernel, PolynomialKernel
 export RationalQuadraticKernel, GammaRationalQuadraticKernel
 export KernelSum, KernelProduct
 
-export SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
-
+export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
 
 using Distances, LinearAlgebra
 using SpecialFunctions: lgamma, besselk

@@ -42,6 +41,5 @@ include("kernels/kernelsum.jl")
 include("kernels/kernelproduct.jl")
 
 include("generic.jl")
-include("squeeze.jl")
 
 end
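These export changes swap the get_params/set_params! pair for params, set! and the new duplicate, which rebuilds an object from a parameter tuple instead of mutating it in place. A minimal usage sketch, assuming duplicate is also defined for the base kernels (only the KernelProduct, KernelSum and ChainTransform methods appear in this diff):

k = SqExponentialKernel(1.0) + Matern32Kernel(2.0) # a KernelSum
θ = params(k)        # (weights, per-kernel parameters)
k2 = duplicate(k, θ) # a fresh KernelSum carrying the given parameters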

src/generic.jl

Lines changed: 0 additions & 8 deletions
@@ -28,11 +28,3 @@ for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52
         $kernel(t::Tr) where {Tr<:Transform} = $kernel{eltype(t),Tr}(t)
     end
 end
-
-function set_params!(k::Kernel,x)
-    set!(k.transform,first(x))
-end
-
-
-params(k::Kernel) = (params(k.transform),)
-opt_params(k::Kernel) = (opt_params(k.transform),)

src/kernels/kernelproduct.jl

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ end
 
 params(k::KernelProduct) = params.(k.kernels)
 opt_params(k::KernelProduct) = opt_params.(k.kernels)
+duplicate(k::KernelProduct,θ) = KernelProduct(duplicate.(k.kernels,θ))
 
 Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
 Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test

src/kernels/kernelsum.jl

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ end
 
 params(k::KernelSum) = (k.weights,params.(k.kernels))
 opt_params(k::KernelSum) = (k.weights,opt_params.(k.kernels))
+duplicate(k::KernelSum,θ) = KernelSum(duplicate.(k.kernels,θ[end]),weights=first(θ))
 
 Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
 Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))
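Here θ is expected to mirror params(k::KernelSum) = (k.weights, params.(k.kernels)): first(θ) supplies the weights and θ[end] the per-kernel parameters handed to each kernel's own duplicate.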

src/squeeze.jl

Lines changed: 0 additions & 24 deletions
This file was deleted.

src/transform/ardtransform.jl

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@ function set!(t::ARDTransform{T},ρ::AbstractVector{T}) where {T<:Real}
 end
 
 params(t::ARDTransform) = t.v
-opt_params(t::ARDTransform) = t.v
 dim(t::ARDTransform) = length(t.v)
 
 function transform(t::ARDTransform,X::AbstractMatrix{<:Real},obsdim::Int)

src/transform/chaintransform.jl

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@ end
 
 set!(t::ChainTransform,θ) = set!.(t.transforms,θ)
 params(t::ChainTransform) = (params.(t.transforms))
-opt_params(t::ChainTransform) = (opt_params.(t.transforms))
+duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))
+
 
 Base.:∘(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
 Base.:∘(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test
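The ∘ overloads compose transforms right-to-left into a ChainTransform, i.e. the right-hand operand is applied first. For example:

t = ScaleTransform(2.0) ∘ FunctionTransform(x -> abs.(x))
# equivalent to ChainTransform([FunctionTransform(x -> abs.(x)), ScaleTransform(2.0)]):
# inputs are first mapped through abs, then scaled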
