Skip to content

Commit 376407d

Browse files
rossviljoentheogfdevmotion
authored
Add NormalizedKernel (#274)
Co-authored-by: Théo Galy-Fajou <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 589daff commit 376407d

File tree

6 files changed

+105
-2
lines changed

6 files changed

+105
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.9.2"
3+
version = "0.9.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
99
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
10+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1011
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,6 +24,7 @@ ChainRulesCore = "0.9"
2324
Compat = "3.7"
2425
CompositionsBase = "0.1"
2526
Distances = "0.10"
27+
FillArrays = "0.10, 0.11"
2628
Functors = "0.1"
2729
Requires = "1.0.1"
2830
SpecialFunctions = "0.8, 0.9, 0.10, 1"

docs/src/kernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ ScaledKernel
123123
KernelSum
124124
KernelProduct
125125
KernelTensorProduct
126+
NormalizedKernel
126127
```
127128

128129
## Multi-output Kernels

src/KernelFunctions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export RationalQuadraticKernel, GammaRationalQuadraticKernel
1717
export GaborKernel, PiecewisePolynomialKernel
1818
export PeriodicKernel, NeuralNetworkKernel
1919
export KernelSum, KernelProduct, KernelTensorProduct
20-
export TransformedKernel, ScaledKernel
20+
export TransformedKernel, ScaledKernel, NormalizedKernel
2121

2222
export Transform,
2323
SelectTransform,
@@ -53,6 +53,7 @@ using ZygoteRules: ZygoteRules
5353
using StatsFuns: logtwo, twoπ
5454
using StatsBase
5555
using TensorCore
56+
using FillArrays
5657

5758
abstract type Kernel end
5859
abstract type SimpleKernel <: Kernel end
@@ -89,6 +90,7 @@ include(joinpath("basekernels", "wiener.jl"))
8990

9091
include(joinpath("kernels", "transformedkernel.jl"))
9192
include(joinpath("kernels", "scaledkernel.jl"))
93+
include(joinpath("kernels", "normalizedkernel.jl"))
9294
include(joinpath("matrix", "kernelmatrix.jl"))
9395
include(joinpath("kernels", "kernelsum.jl"))
9496
include(joinpath("kernels", "kernelproduct.jl"))

src/kernels/normalizedkernel.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
NormalizedKernel(k::Kernel)
3+
4+
A normalized kernel derived from `k`.
5+
6+
# Definition
7+
8+
For inputs ``x, x'``, the normalized kernel ``\\widetilde{k}`` derived from
9+
kernel ``k`` is defined as
10+
```math
11+
\\widetilde{k}(x, x'; k) = \\frac{k(x, x')}{\\sqrt{k(x, x) k(x', x')}}.
12+
```
13+
"""
14+
struct NormalizedKernel{Tk<:Kernel} <: Kernel
15+
kernel::Tk
16+
end
17+
18+
@functor NormalizedKernel
19+
20+
::NormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt.kernel(x, x) * κ.kernel(y, y))
21+
22+
function kernelmatrix::NormalizedKernel, x::AbstractVector, y::AbstractVector)
23+
return kernelmatrix.kernel, x, y) ./
24+
sqrt.(
25+
kernelmatrix_diag.kernel, x) .* permutedims(kernelmatrix_diag.kernel, y))
26+
)
27+
end
28+
29+
function kernelmatrix::NormalizedKernel, x::AbstractVector)
30+
x_diag = kernelmatrix_diag.kernel, x)
31+
return kernelmatrix.kernel, x) ./ sqrt.(x_diag .* permutedims(x_diag))
32+
end
33+
34+
function kernelmatrix_diag::NormalizedKernel, x::AbstractVector)
35+
first_x = first(x)
36+
return Ones{typeof(κ(first_x, first_x))}(length(x))
37+
end
38+
39+
function kernelmatrix_diag::NormalizedKernel, x::AbstractVector, y::AbstractVector)
40+
return kernelmatrix_diag.kernel, x, y) ./
41+
sqrt.(kernelmatrix_diag.kernel, x) .* kernelmatrix_diag.kernel, y))
42+
end
43+
44+
function kernelmatrix!(
45+
K::AbstractMatrix, κ::NormalizedKernel, x::AbstractVector, y::AbstractVector
46+
)
47+
kernelmatrix!(K, κ.kernel, x, y)
48+
K ./=
49+
sqrt.(kernelmatrix_diag.kernel, x) .* permutedims(kernelmatrix_diag.kernel, y)))
50+
return K
51+
end
52+
53+
function kernelmatrix!(K::AbstractMatrix, κ::NormalizedKernel, x::AbstractVector)
54+
kernelmatrix!(K, κ.kernel, x)
55+
x_diag = kernelmatrix_diag.kernel, x)
56+
K ./= sqrt.(x_diag .* permutedims(x_diag))
57+
return K
58+
end
59+
60+
function kernelmatrix_diag!(
61+
K::AbstractVector, κ::NormalizedKernel, x::AbstractVector, y::AbstractVector
62+
)
63+
kernelmatrix_diag!(K, κ.kernel, x, y)
64+
K ./= sqrt.(kernelmatrix_diag.kernel, x) .* kernelmatrix_diag.kernel, y))
65+
return K
66+
end
67+
68+
function kernelmatrix_diag!(K::AbstractVector, κ::NormalizedKernel, x::AbstractVector)
69+
first_x = first(x)
70+
return fill!(K, κ(first_x, first_x))
71+
end
72+
73+
Base.show(io::IO, κ::NormalizedKernel) = printshifted(io, κ, 0)
74+
75+
function printshifted(io::IO, κ::NormalizedKernel, shift::Int)
76+
println(io, "Normalized Kernel:")
77+
for _ in 1:(shift + 1)
78+
print(io, "\t")
79+
end
80+
return printshifted(io, κ.kernel, shift + 1)
81+
end

test/kernels/normalizedkernel.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@testset "normalizedkernel" begin
2+
rng = MersenneTwister(123456)
3+
x = randn(rng)
4+
y = randn(rng)
5+
6+
k = 4 * SqExponentialKernel()
7+
kn = NormalizedKernel(k)
8+
@test kn(x, y) == k(x, y) / sqrt(k(x, x) * k(y, y))
9+
@test kn(x, x) one(x) atol = 1e-5
10+
11+
# Standardised tests.
12+
TestUtils.test_interface(kn, Float64)
13+
test_ADs(x -> NormalizedKernel(exp(x[1]) * SqExponentialKernel()), rand(1))
14+
15+
test_params(kn, k)
16+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ include("test_utils.jl")
123123
include(joinpath("kernels", "overloads.jl"))
124124
include(joinpath("kernels", "scaledkernel.jl"))
125125
include(joinpath("kernels", "transformedkernel.jl"))
126+
include(joinpath("kernels", "normalizedkernel.jl"))
126127
end
127128
@info "Ran tests on Kernel"
128129

0 commit comments

Comments
 (0)