|
| 1 | +@doc raw""" |
| 2 | + LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix) |
| 3 | +
|
| 4 | +The kernel associated with the Semiparametric Latent Factor Model, introduced by |
| 5 | +Seeger, Teh and Jordan (2005). |
| 6 | +
|
| 7 | +``k((x, p_x), (y, p_y)) = \Sum^{Q}_{q=1} A_{p_xq}g_q(x, y)A_{p_yq} + e((x, p_x), (y, p_y))`` |
| 8 | +
|
| 9 | +# Arguments |
| 10 | +- `g`: a collection of kernels, one for each latent process |
| 11 | +- `e`: a [`MOKernel`](@ref) - multi-output kernel |
| 12 | +- `A::AbstractMatrix`: a matrix of weights for the kernels of size `(out_dim, length(g))` |
| 13 | +
|
| 14 | +
|
| 15 | +# Reference: |
| 16 | +- [Seeger, Teh, and Jordan (2005)](https://infoscience.epfl.ch/record/161465/files/slfm-long.pdf) |
| 17 | +
|
| 18 | +""" |
| 19 | +struct LatentFactorMOKernel{Tg, Te <: MOKernel, TA <: AbstractMatrix} <: MOKernel |
| 20 | + g::Tg |
| 21 | + e::Te |
| 22 | + A::TA |
| 23 | + function LatentFactorMOKernel(g, e::MOKernel, A::AbstractMatrix) |
| 24 | + all(gi isa Kernel for gi in g) || error("`g` should be an collection of kernels") |
| 25 | + length(g) == size(A, 2) || |
| 26 | + error("Size of `A` not compatible with the given array of kernels `g`") |
| 27 | + return new{typeof(g), typeof(e), typeof(A)}(g, e, A) |
| 28 | + end |
| 29 | +end |
| 30 | + |
| 31 | +function (κ::LatentFactorMOKernel)((x, px)::Tuple{Any, Int}, (y, py)::Tuple{Any, Int}) |
| 32 | + cov_f = sum(κ.A[px, q] * κ.g[q](x, y) * κ.A[py, q] for q in 1:length(κ.g)) |
| 33 | + return cov_f + κ.e((x, px), (y, py)) |
| 34 | +end |
| 35 | + |
| 36 | +function kernelmatrix(k::LatentFactorMOKernel, x::MOInput, y::MOInput) |
| 37 | + x.out_dim == y.out_dim || error("`x` and `y` should have the same output dimension") |
| 38 | + x.out_dim == size(k.A, 1) || |
| 39 | + error("Kernel not compatible with the given multi-output inputs") |
| 40 | + |
| 41 | + # Weights matrix ((out_dim x out_dim) x length(k.g)) |
| 42 | + W = [col * col' for col in eachcol(k.A)] |
| 43 | + |
| 44 | + # Latent kernel matrix ((N x N) x length(k.g)) |
| 45 | + H = [gi.(x.x, permutedims(y.x)) for gi in k.g] |
| 46 | + |
| 47 | + # Weighted latent kernel matrix ((N*out_dim) x (N*out_dim)) |
| 48 | + W_H = sum(kron(Wi, Hi) for (Wi, Hi) in zip(W, H)) |
| 49 | + |
| 50 | + return W_H .+ kernelmatrix(k.e, x, y) |
| 51 | +end |
| 52 | + |
| 53 | +function Base.show(io::IO, k::LatentFactorMOKernel) |
| 54 | + print(io, "Semi-parametric Latent Factor Multi-Output Kernel") |
| 55 | +end |
| 56 | + |
| 57 | +function Base.show(io::IO, ::MIME"text/plain", k::LatentFactorMOKernel) |
| 58 | + print(io, "Semi-parametric Latent Factor Multi-Output Kernel\n\tgᵢ: ") |
| 59 | + join(io, k.g, "\n\t\t") |
| 60 | + print(io, "\n\teᵢ: ") |
| 61 | + join(io, k.e, "\n\t\t") |
| 62 | +end |
0 commit comments