1 change: 1 addition & 0 deletions docs/src/kernels.md
@@ -124,6 +124,7 @@ TransformedKernel
ScaledKernel
KernelSum
KernelProduct
KernelTensorSum
KernelTensorProduct
NormalizedKernel
```
4 changes: 3 additions & 1 deletion src/KernelFunctions.jl
@@ -15,9 +15,10 @@ export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct, KernelTensorProduct
export KernelSum, KernelProduct, KernelTensorSum, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
export GibbsKernel
export ⊕

export Transform,
SelectTransform,
@@ -108,6 +109,7 @@ include("kernels/normalizedkernel.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
include("kernels/kerneltensorsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/overloads.jl")
include("kernels/neuralkernelnetwork.jl")
110 changes: 110 additions & 0 deletions src/kernels/kerneltensorsum.jl
@@ -0,0 +1,110 @@
"""
KernelTensorSum

Tensor sum of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
sum of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\sum_{i=1}^n k_i(x_i, x'_i).
```
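
For ``n = 2`` this reads ``k((x_1, x_2), (x'_1, x'_2); k_1, k_2) = k_1(x_1, x'_1) + k_2(x_2, x'_2)``.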

# Construction

The simplest way to specify a `KernelTensorSum` is to use the `⊕` operator (can be typed by `\\oplus<tab>`).
```jldoctest tensorsum
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊕ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorSum` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorSum` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest tensorsum
julia> KernelTensorSum(k1, k2) == k1 ⊕ k2
true

julia> KernelTensorSum((k1, k2)) == k1 ⊕ k2
true

julia> KernelTensorSum([k1, k2]) == k1 ⊕ k2
true
```
"""
struct KernelTensorSum{K} <: Kernel
kernels::K
end

function KernelTensorSum(kernel::Kernel, kernels::Kernel...)
return KernelTensorSum((kernel, kernels...))
end

@functor KernelTensorSum

Base.length(kernel::KernelTensorSum) = length(kernel.kernels)

function (kernel::KernelTensorSum)(x, y)
if !((nx = length(x)) == (ny = length(y)) == (nkernels = length(kernel)))
throw(
DimensionMismatch(
"number of kernels ($nkernels) and number of features (x=$nx, y=$ny) are not consistent",
),
)
end
return sum(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

function validate_domain(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
return (dx = dim(x)) == (dy = dim(y)) == (nkernels = length(k)) || error(
"number of kernels ($nkernels) and groups of features (x=$dx, y=$dy) are not consistent",
)
end

function validate_domain(k::KernelTensorSum, x::AbstractVector)
return validate_domain(k, x, x)
end

function kernelmatrix(k::KernelTensorSum, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, +, k.kernels, slices(x))
end

function kernelmatrix(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
validate_domain(k, x, y)
return mapreduce(kernelmatrix, +, k.kernels, slices(x), slices(y))
end

function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x))
end

function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
validate_domain(k, x, y)
return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x), slices(y))
end

function Base.:(==)(x::KernelTensorSum, y::KernelTensorSum)
return (
length(x.kernels) == length(y.kernels) &&
all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
)
end

Base.show(io::IO, kernel::KernelTensorSum) = printshifted(io, kernel, 0)

function printshifted(io::IO, kernel::KernelTensorSum, shift::Int)
print(io, "Tensor sum of ", length(kernel), " kernels:")
for k in kernel.kernels
print(io, "\n")
for _ in 1:(shift + 1)
print(io, "\t")
end
printshifted(io, k, shift + 2)
end
end
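
For orientation, a minimal usage sketch of what the new file implements (assuming this PR is loaded; the component kernels below are arbitrary choices):

```julia
using KernelFunctions

# Each component kernel acts on one feature group of the input.
k = SqExponentialKernel() ⊕ LinearKernel()   # equivalent to KernelTensorSum(k1, k2)

# Direct evaluation on grouped inputs (one entry per kernel):
x, y = (1.0, 2.0), (0.5, 1.5)
k(x, y) ≈ SqExponentialKernel()(1.0, 0.5) + LinearKernel()(2.0, 1.5)   # true

# kernelmatrix splits a ColVecs input by dimension and sums the per-dimension matrices:
X = rand(2, 10)   # 10 two-dimensional inputs stored as columns
kernelmatrix(k, ColVecs(X)) ≈
    kernelmatrix(SqExponentialKernel(), X[1, :]) + kernelmatrix(LinearKernel(), X[2, :])   # true
```

The `validate_domain` checks above guard exactly this contract: the number of feature groups in each input must equal the number of component kernels.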
4 changes: 4 additions & 0 deletions src/kernels/overloads.jl
@@ -1,7 +1,11 @@
function tensor_sum end
const ⊕ = tensor_sum

for (M, op, T) in (
(:Base, :+, :KernelSum),
(:Base, :*, :KernelProduct),
(:TensorCore, :tensor, :KernelTensorProduct),
(:KernelFunctions, :⊕, :KernelTensorSum),
)
@eval begin
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)
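For readers skimming the `@eval` loop: the new table entry makes it emit, for `⊕`, the same pattern of definitions as for `+`, `*`, and `⊗`. A hand-expanded sketch of the visible method, plus a usage check (the loop body continues beyond this hunk):

```julia
using KernelFunctions

# The visible @eval line expands, for the new (:KernelFunctions, :⊕, :KernelTensorSum) entry, to:
#     KernelFunctions.:⊕(k1::Kernel, k2::Kernel) = KernelTensorSum(k1, k2)

# `⊕` is a const alias of the unexported `tensor_sum`, so both spellings construct the same kernel:
k = SqExponentialKernel() ⊕ LinearKernel()
k isa KernelTensorSum                                                    # true
k == KernelFunctions.tensor_sum(SqExponentialKernel(), LinearKernel())   # true
```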
67 changes: 67 additions & 0 deletions test/kernels/kerneltensorsum.jl
@@ -0,0 +1,67 @@
@testset "kerneltensorsum" begin
rng = MersenneTwister(123456)
u1 = rand(rng, 10)
u2 = rand(rng, 10)
v1 = rand(rng, 5)
v2 = rand(rng, 5)

# kernels
k1 = SqExponentialKernel()
k2 = ExponentialKernel()
kernel1 = KernelTensorSum(k1, k2)
kernel2 = KernelTensorSum([k1, k2])

@test kernel1 == kernel2
@test kernel1.kernels == (k1, k2) === KernelTensorSum((k1, k2)).kernels
for (_k1, _k2) in Iterators.product(
(k1, KernelTensorSum((k1,)), KernelTensorSum([k1])),
(k2, KernelTensorSum((k2,)), KernelTensorSum([k2])),
)
@test kernel1 == _k1 ⊕ _k2
end
@test length(kernel1) == length(kernel2) == 2
@test string(kernel1) == (
"Independent sum of 2 kernels:\n" *
"\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
"\tExponential Kernel (metric = Euclidean(0.0))"
)
@test_throws DimensionMismatch kernel1(rand(3), rand(3))

@testset "val" begin
for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2]))
val = k1(x[1], y[1]) + k2(x[2], y[2])

@test kernel1(x, y) == kernel2(x, y) == val
end
end

# Standardised tests.
TestUtils.test_interface(kernel1, ColVecs{Float64})
TestUtils.test_interface(kernel1, RowVecs{Float64})
TestUtils.test_interface(
KernelTensorSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
)
test_ADs(
x -> KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
rand(1);
dims=[2, 2],
)
types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}]
test_interface_ad_perf(2.1, StableRNG(123456), types) do c
KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=c))
end
test_params(KernelTensorSum(k1, k2), (k1, k2))

@testset "single kernel" begin
kernel = KernelTensorSum(k1)
@test length(kernel) == 1

@testset "eval" begin
for (x, y) in (((v1,), (v2,)), ([v1], [v2]))
val = k1(x[1], y[1])

@test kernel(x, y) == val
end
end
end
end
5 changes: 3 additions & 2 deletions test/kernels/overloads.jl
@@ -5,8 +5,9 @@
k2 = SqExponentialKernel()
k3 = RationalQuadraticKernel()

for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct))
if T === KernelTensorProduct
for (op, T) in
((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct), (⊕, KernelTensorSum))
if T === KernelTensorProduct || T === KernelTensorSum
v2_1 = rand(rng, 2)
v2_2 = rand(rng, 2)
v3_1 = rand(rng, 3)
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -125,6 +125,7 @@ include("test_utils.jl")
include("kernels/kernelproduct.jl")
include("kernels/kernelsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/kerneltensorsum.jl")
include("kernels/overloads.jl")
include("kernels/scaledkernel.jl")
include("kernels/transformedkernel.jl")
Expand Down