Skip to content

Commit e051f51

Browse files
Adding naive LMM model (#304)
1 parent 65d7088 commit e051f51

File tree

6 files changed

+99
-2
lines changed

6 files changed

+99
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.5"
3+
version = "0.10.6"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/kernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,5 @@ MOKernel
134134
IndependentMOKernel
135135
LatentFactorMOKernel
136136
IntrinsicCoregionMOKernel
137+
LinearMixingModelKernel
137138
```

src/KernelFunctions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3636
export ColVecs, RowVecs
3737

3838
export MOInput
39-
export IndependentMOKernel, LatentFactorMOKernel, IntrinsicCoregionMOKernel
39+
export IndependentMOKernel,
40+
LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel
4041

4142
# Reexports
4243
export tensor, , compose
@@ -106,6 +107,7 @@ include(joinpath("mokernels", "moinput.jl"))
106107
include(joinpath("mokernels", "independent.jl"))
107108
include(joinpath("mokernels", "slfm.jl"))
108109
include(joinpath("mokernels", "intrinsiccoregion.jl"))
110+
include(joinpath("mokernels", "lmm.jl"))
109111

110112
include("chainrules.jl")
111113
include("zygoterules.jl")

src/mokernels/lmm.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
@doc raw"""
2+
LinearMixingModelKernel(g, e::MOKernel, A::AbstractMatrix)
3+
4+
Kernel associated with the linear mixing model.
5+
6+
# Definition
7+
8+
For inputs ``x, x'`` and output dimensions ``p_x, p_{x'}'``, the kernel is defined as[^BPTHST]
9+
```math
10+
k\big((x, p_x), (x, p_{x'})\big) = H_{:,p_{x}}K(x, x')H_{:,p_{x'}}
11+
```
12+
where ``K(x, x') = Diag(k_1(x, x'), ..., k_m(x, x'))`` with zero off-diagonal entries.
13+
``H_{:,p_{x}}`` is the ``p_x``-th column (`p_x`-th output) of ``H \in \mathbb{R}^{m \times p}``
14+
representing ``m`` basis vectors for the ``p`` dimensional output space of ``f``.
15+
``k_1, \ldots, k_m`` are ``m`` kernels, one for each latent process, ``H`` is a
16+
mixing matrix of ``m`` basis vectors spanning the output space.
17+
18+
[^BPTHST]: Wessel P. Bruinsma, Eric Perim, Will Tebbutt, J. Scott Hosking, Arno Solin, Richard E. Turner (2020). [Scalable Exact Inference in Multi-Output Gaussian Processes](https://arxiv.org/pdf/1911.06287.pdf).
19+
"""
20+
struct LinearMixingModelKernel{Tk<:AbstractVector{<:Kernel},Th<:AbstractMatrix} <: MOKernel
21+
K::Tk
22+
H::Th
23+
end
24+
25+
function LinearMixingModelKernel(k::Kernel, H::AbstractMatrix)
26+
return LinearMixingModelKernel(Fill(k, size(H, 1)), H)
27+
end
28+
29+
function::LinearMixingModelKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
30+
(px > size.H, 2) || py > size.H, 2) || px < 1 || py < 1) &&
31+
error("`px` and `py` must be within the range of the number of outputs")
32+
return sum.H[i, px] * κ.K[i](x, y) * κ.H[i, py] for i in 1:length.K))
33+
end
34+
35+
function Base.show(io::IO, k::LinearMixingModelKernel)
36+
return print(io, "Linear Mixing Model Multi-Output Kernel")
37+
end
38+
39+
function Base.show(io::IO, mime::MIME"text/plain", k::LinearMixingModelKernel)
40+
print(io, "Linear Mixing Model Multi-Output Kernel. Kernels:")
41+
for k in k.K
42+
print(io, "\n\t")
43+
show(io, mime, k)
44+
end
45+
end

test/mokernels/lmm.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
@testset "lmm" begin
2+
rng = MersenneTwister(123)
3+
FDM = FiniteDifferences.central_fdm(5, 1)
4+
N = 10
5+
in_dim = 3
6+
out_dim = 6
7+
x1 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
8+
x2 = MOInput([rand(rng, in_dim) for _ in 1:N], out_dim)
9+
H = rand(4, 6)
10+
11+
k = LinearMixingModelKernel(
12+
[Matern32Kernel(), SqExponentialKernel(), FBMKernel(), Matern32Kernel()], H
13+
)
14+
@test k isa LinearMixingModelKernel
15+
@test k isa MOKernel
16+
@test k isa Kernel
17+
@test k(x1[1], x2[1]) isa Real
18+
19+
@test string(k) == "Linear Mixing Model Multi-Output Kernel"
20+
@test repr("text/plain", k) == (
21+
"Linear Mixing Model Multi-Output Kernel. Kernels:\n" *
22+
"\tMatern 3/2 Kernel (metric = Euclidean(0.0))\n" *
23+
"\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
24+
"\tFractional Brownian Motion Kernel (h = 0.5)\n" *
25+
"\tMatern 3/2 Kernel (metric = Euclidean(0.0))"
26+
)
27+
28+
k = LinearMixingModelKernel(SEKernel(), H)
29+
30+
@test k isa LinearMixingModelKernel
31+
@test k isa MOKernel
32+
@test k isa Kernel
33+
@test length(k.K) == 4
34+
for kernel in k.K
35+
@test isa(kernel, SEKernel)
36+
end
37+
38+
@test string(k) == "Linear Mixing Model Multi-Output Kernel"
39+
@test repr("text/plain", k) == (
40+
"Linear Mixing Model Multi-Output Kernel. Kernels:\n" *
41+
"\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
42+
"\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
43+
"\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
44+
"\tSquared Exponential Kernel (metric = Euclidean(0.0))"
45+
)
46+
47+
test_ADs(k)
48+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ include("test_utils.jl")
140140
include(joinpath("mokernels", "independent.jl"))
141141
include(joinpath("mokernels", "slfm.jl"))
142142
include(joinpath("mokernels", "intrinsiccoregion.jl"))
143+
include(joinpath("mokernels", "lmm.jl"))
143144
end
144145
@info "Ran tests on Multi-Output Kernels"
145146

0 commit comments

Comments
 (0)