Skip to content

Commit 4789223

Browse files
committed
Implemented Kronecker matrices with Kronecker.jl
1 parent 073fdf9 commit 4789223

File tree

5 files changed

+28
-7
lines changed

5 files changed

+28
-7
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ version = "0.2.4"
55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
8+
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
11+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1012
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1113
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1214
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -15,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1517
[compat]
1618
Compat = "2.2, 3.2"
1719
Distances = "0.8"
20+
Kronecker = "0.3.1"
1821
PDMats = "0.9"
1922
SpecialFunctions = "0.8, 0.9, 0.10"
2023
StatsFuns = "0.8, 0.9"

src/KernelFunctions.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat # Main matrix functions
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa # Main matrix functions
44
export params, duplicate, set! # Helpers
55

66
export Kernel
@@ -17,6 +17,7 @@ export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransf
1717
export NystromFact, nystrom
1818

1919
using Compat
20+
using Requires
2021
using Distances, LinearAlgebra
2122
using SpecialFunctions: logabsgamma, besselk
2223
using ZygoteRules: @adjoint
@@ -41,7 +42,6 @@ for k in ["exponential","matern","polynomial","constant","rationalquad","exponen
4142
include(joinpath("kernels",k*".jl"))
4243
end
4344
include("matrix/kernelmatrix.jl")
44-
include("matrix/kernelpdmat.jl")
4545
include("kernels/kernelsum.jl")
4646
include("kernels/kernelproduct.jl")
4747
include("approximations/nystrom.jl")
@@ -50,4 +50,9 @@ include("generic.jl")
5050

5151
include("zygote_adjoints.jl")
5252

53+
function __init__()
54+
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
55+
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
56+
end
57+
5358
end

src/matrix/kernelkroeneckermat.jl renamed to src/matrix/kernelkroneckermat.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1+
using .Kronecker
2+
3+
export kernelkronmat
4+
15
function kernelkronmat(
26
κ::Kernel,
37
X::AbstractVector,
48
dims::Int
59
)
6-
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices (see `iskroncompatible()`)"
10+
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see `iskroncompatible()`)"
711
k = kernelmatrix(κ,reshape(X,:,1),obsdim=1)
8-
K = kron()
12+
kronecker(k,dims)
913
end
1014

1115
function kernelkronmat(
1216
κ::Kernel,
1317
X::AbstractVector{<:AbstractVector};
1418
obsdim::Int=defaultobs
1519
)
16-
@assert iskroncompatible(κ) "The kernel chosed is not compatible for kroenecker matrices"
20+
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices"
1721
Ks = kernelmatrix.(κ,X,obsdim=obsdim)
18-
K = kron(Ks)
22+
K = reduce(,Ks)
1923
end
2024

2125

src/matrix/kernelpdmat.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
export kernelpdmat
2+
13
"""
24
Compute a positive-definite matrix in the form of a `PDMat` matrix see [PDMats.jl]() with the cholesky decomposition precomputed
35
The algorithm recursively tries to add recursively a diagonal nugget until positive definiteness is achieved or that the noise is too big

test/test_kernelmatrix.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Distances, LinearAlgebra
22
using Test
33
using KernelFunctions
44
using PDMats
5-
5+
using Kronecker
66
dims = [10,5]
77

88
A = rand(dims...)
@@ -67,4 +67,11 @@ k = SqExponentialKernel()
6767
# @test_throws ErrorException kernelpdmat(k,ones(100,100),obsdim=obsdim)
6868
end
6969
end
70+
@testset "Kronecker" begin
71+
x = range(0,1,length=10)
72+
X = vcat(collect.(Iterators.product(x,x))'...)
73+
@test all(collect(kernelkronmat(k,collect(x),2)).≈kernelmatrix(k,X,obsdim=1))
74+
@test all(collect(kernelkronmat(k,[x,x])).≈kernelmatrix(k,X,obsdim=1))
75+
@test_throws AssertionError kernelkronmat(LinearKernel(),collect(x),2)
76+
end
7077
end

0 commit comments

Comments
 (0)