Skip to content

Commit 6d8afad

Browse files
committed
Improved tests considerably and corrected a lot of bugs
1 parent efe525c commit 6d8afad

13 files changed

+202
-39
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[![Build Status](https://travis-ci.org/theogf/KernelFunctions.jl.svg?branch=master)](https://travis-ci.org/theogf/AugmentedGaussianProcesses.jl)
2+
[![Coverage Status](https://coveralls.io/repos/github/theogf/KernelFunctions.jl/badge.svg?branch=master)](https://coveralls.io/github/theogf/KernelFunctions.jl?branch=master)
23
[![Documentation](https://img.shields.io/badge/docs-dev-blue.svg)](https://theogf.github.io/KernelFunctions.jl/dev/)
34
# KernelFunctions.jl (WIP)
45
Julia Package for kernel functions for machine learning

docs/src/transform.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
11
# Transform
2+
3+
`Transform` is the object that takes care of transforming the input data before distances are being computed. It can be as standard as `IdentityTransform` returning the same input, can be a scalar with `ScaleTransform` multiplying the vectors by a scalar or a vector.
4+
There is a more general `Transform`: `FunctionTransform` that uses a function and apply it on each vector via `mapslices`.
5+
You can also create a pipeline of `Transform` via `TransformChain`. For example `LowRankTransform(rand(10,5))∘ScaleTransform(2.0)`.
6+
7+
## Transforms :
8+
9+
```@docs
10+
IdentityTransform
11+
ScaleTransform
12+
LowRankTransform
13+
FunctionTransform
14+
TransformChain
15+
```

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
44
export Kernel
5+
export ConstantKernel, WhiteKernel, ZeroKernel
56
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
7+
export ExponentiatedKernel
68
export MaternKernel, Matern32Kernel, Matern52Kernel
79
export LinearKernel, PolynomialKernel
8-
export ConstantKernel, WhiteKernel, ZeroKernel
910

1011

1112

src/generic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
55
@eval begin
66
@inline::$k)(d::Real) = kappa(κ,d)
7-
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
8-
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9-
@inline::$k)(x::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,obsdim=obsdim)
7+
@inline::$k)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate.metric,transform(κ,x),transform(κ,y)))
8+
@inline::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
9+
@inline::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)
1010
end
1111
end
1212

src/kernels/constant.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ struct ZeroKernel{T,Tr} <: Kernel{T,Tr}
88
transform::Tr
99
metric::Delta
1010
function ZeroKernel{T,Tr}(t::Tr) where {T,Tr<:Transform}
11-
new{eltype{Tr},Tr}(t,Delta())
11+
new{T,Tr}(t,Delta())
1212
end
1313
end
1414

15+
function ZeroKernel(t::Tr=IdentityTransform()) where {Tr<:Transform}
16+
ZeroKernel{eltype(Tr),Tr}(t)
17+
end
18+
1519
@inline kappa::ZeroKernel,d::T) where {T<:Real} = zero(T)
1620

1721
"""
@@ -30,11 +34,7 @@ struct WhiteKernel{T,Tr} <: Kernel{T,Tr}
3034
end
3135
end
3236

33-
function WhiteKernel()
34-
WhiteKernel{Float64,IdentityTransform}(IdentityTransform())
35-
end
36-
37-
function WhiteKernel(t::Tr) where {Tr<:Transform}
37+
function WhiteKernel(t::Tr=IdentityTransform()) where {Tr<:Transform}
3838
WhiteKernel{eltype(Tr),Tr}(t)
3939
end
4040

src/kernels/exponential.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080

8181
function GammaExponentialKernel::T₁=1.0,gamma::T₂=2.0) where {T₁<:Real,T₂<:Real}
8282
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T₂), "gamma > 0")
83-
Polynomial{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),gamma)
83+
GammaExponentialKernel{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),gamma)
8484
end
8585

8686
function GammaExponentialKernel::A,gamma::T₁=2.0) where {A<:AbstractVector{<:Real},T₁<:Real}
@@ -93,4 +93,4 @@ function GammaExponentialKernel(t::Tr,gamma::T₁=2.0) where {Tr<:Transform,T₁
9393
GammaExponentialKernel{eltype(Tr),Tr,T₁}(t,gamma)
9494
end
9595

96-
@inline kappa::GammaExponentialKernel, d²::Real) where {T} = exp(-^γ)
96+
@inline kappa::GammaExponentialKernel, d²::Real) where {T} = exp(-^κ.γ)

src/kernels/exponentiated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
struct ExponentiatedKernel{T,Tr} <: Kernel{T,Tr}
1111
transform::Tr
1212
metric::DotProduct
13-
function ExponentiatedKernel{T}(transform::Tr) where {T,Tr<:Transform}
13+
function ExponentiatedKernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
1414
return new{T,Tr}(transform,DotProduct())
1515
end
1616
end

src/transform/scaletransform.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Scale Transform
3+
"""
14
struct ScaleTransform{T<:Union{Real,AbstractVector{<:Real}}} <: Transform
25
s::T
36
end
@@ -26,7 +29,7 @@ function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix
2629
end
2730
_transform(t,X,obsdim)
2831
end
29-
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = t.s .* x
32+
transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real},obsdim::Int=defaultobs) = t.s .* x
3033
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.s'.*X : t.s .* X
3134

3235
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int=defaultobs) = t.s .* x

src/transform/transform.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export Transform, ScaleTransform, LowRankTransform, FunctionTransform, TransformChain
1+
export Transform, IdentityTransform, ScaleTransform, LowRankTransform, FunctionTransform, ChainTransform
22

33

44
abstract type Transform end
@@ -7,27 +7,27 @@ include("scaletransform.jl")
77
include("lowranktransform.jl")
88
include("functiontransform.jl")
99

10-
struct TransformChain <: Transform
10+
struct ChainTransform <: Transform
1111
transforms::Vector{Transform}
1212
end
1313

14-
Base.length(t::TransformChain) = length(t.transforms)
14+
Base.length(t::ChainTransform) = length(t.transforms)
1515

16-
function TransformChain(v::AbstractVector{<:Transform})
17-
TransformChain(v)
16+
function ChainTransform(v::AbstractVector{<:Transform})
17+
ChainTransform(v)
1818
end
1919

20-
function transform(t::TransformChain,X::T,obsdim::Int=defaultobs) where {T}
20+
function transform(t::ChainTransform,X::T,obsdim::Int=defaultobs) where {T}
2121
Xtr = copy(X)
2222
for tr in t.transforms
2323
Xtr = transform(tr,Xtr,obsdim)
2424
end
2525
return Xtr
2626
end
2727

28-
Base.:(t₁::Transform,t₂::Transform) = TransformChain([t₂,t₁])
29-
Base.:(t::Transform,tc::TransformChain) = TransformChain(vcat(tc.transforms,t))
30-
Base.:(tc::TransformChain,t::Transform) = TransformChain(vcat(t,tc.transforms))
28+
Base.:(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
29+
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t))
30+
Base.:(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))
3131

3232
struct IdentityTransform <: Transform end
3333

test/test_distances.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Test
2+
using Distances, LinearAlgebra
3+
using KernelFunctions
4+
5+
A = rand(10,5)
6+
B = rand(20,5)
7+
@testset "Distance" begin
8+
@testset "Dot Product" begin
9+
d = KernelFunctions.DotProduct()
10+
@test diag(pairwise(d,A,dims=2)) == dot.(eachcol(A),eachcol(A))
11+
@test_throws DimensionMismatch d(rand(3),rand(4))
12+
@test d(3.0,2.0) == 6.0
13+
end
14+
@testset "Delta" begin
15+
d = KernelFunctions.Delta()
16+
@test pairwise(d,A,dims=1) == Matrix(I,size(A,1),size(A,1))
17+
@test pairwise(d,A,B,dims=1) == zeros(size(A,1),size(B,1))
18+
@test d(1,2) == 0
19+
@test d(1,1) == 1
20+
@test_throws DimensionMismatch d(rand(3),rand(4))
21+
end
22+
end

0 commit comments

Comments
 (0)