Skip to content

Commit d8ec7b9

Browse files
committed
Defining the new kernelmatrix methods
1 parent 3554faa commit d8ec7b9

File tree

3 files changed

+155
-94
lines changed

3 files changed

+155
-94
lines changed

src/kernels/transformedkernel.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,17 @@ function printshifted(io::IO, κ::TransformedKernel, shift::Int)
3939
printshifted(io, κ.kernel, shift)
4040
print(io,"\n" * ("\t" ^ (shift + 1)) * "- $(κ.transform)")
4141
end
42+
43+
# Kernel matrix operations
44+
45+
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
46+
kernelmatrix!(K, kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
47+
48+
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs) =
49+
kernelmatrix!(K, kernel(κ), apply.transform, X, obsdim = obsdim), apply.transform, Y, obsdim = obsdim), obsdim = obsdim)
50+
51+
kernelmatrix::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
52+
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
53+
54+
kernelmatrix::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs) =
55+
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), apply.transform, Y, obsdim = obsdim), obsdim = obsdim)

src/matrix/kernelmatrix.jl

Lines changed: 137 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -6,59 +6,82 @@ In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will b
66
"""
77
kernelmatrix!
88

9-
109
function kernelmatrix!(
11-
K::AbstractMatrix,
12-
κ::Kernel,
13-
X::AbstractMatrix;
14-
obsdim::Int = defaultobs
15-
)
16-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
17-
if !check_dims(K,X,X,feature_dim(obsdim),obsdim)
18-
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
19-
end
20-
map!(x->kappa(κ,x),K,pairwise(metric(κ),X,dims=obsdim))
10+
K::AbstractMatrix,
11+
κ::SimpleKernel,
12+
X::AbstractMatrix;
13+
obsdim::Int = defaultobs,
14+
)
15+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
16+
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
17+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
18+
end
19+
map!(x -> kappa(κ, x), K, pairwise(metric(κ), X, dims = obsdim))
2120
end
2221

23-
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
24-
kernelmatrix!(K, kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
22+
function kernelmatrix!(
23+
K::AbstractMatrix,
24+
κ::BaseKernel,
25+
X::AbstractMatrix;
26+
obsdim::Int = defaultobs
27+
)
28+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
29+
if obsdim == 1
30+
@compat kernelmatrix!(K, κ, ColVecs(X))
31+
else
32+
@compat kernelmatrix!(K, κ, RowVecs(X))
33+
end
34+
end
2535

2636
function kernelmatrix!(
27-
K::AbstractMatrix,
28-
κ::Kernel,
29-
X::AbstractMatrix,
30-
Y::AbstractMatrix;
31-
obsdim::Int = defaultobs
32-
)
33-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
34-
if !check_dims(K,X,Y,feature_dim(obsdim),obsdim)
35-
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
36-
end
37-
map!(x->kappa(κ,x),K,pairwise(metric(κ),X,Y,dims=obsdim))
37+
K::AbstractMatrix,
38+
κ::BaseKernel,
39+
X::AbstractVector
40+
)
41+
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
42+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
43+
end
44+
map!(κ, K, X, X')
3845
end
3946

40-
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs) =
41-
kernelmatrix!(K, kernel(κ), apply.transform, X, obsdim = obsdim), apply.transform, Y, obsdim = obsdim), obsdim = obsdim)
47+
function kernelmatrix!(
48+
K::AbstractMatrix,
49+
κ::SimpleKernel,
50+
X::AbstractMatrix,
51+
Y::AbstractMatrix;
52+
obsdim::Int = defaultobs,
53+
)
54+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
55+
if !check_dims(K, X, Y, feature_dim(obsdim), obsdim)
56+
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
57+
end
58+
map!(κ, K, pairwise(metric(κ), X, Y, dims = obsdim))
59+
end
4260

43-
## Apply kernel on two reals ##
44-
function _kernel::Kernel, x::Real, y::Real)
45-
_kernel(κ, [x], [y])
61+
function kernelmatrix!(
62+
K::AbstractMatrix,
63+
κ::BaseKernel,
64+
X::AbstractMatrix,
65+
Y::AbstractMatrix;
66+
obsdim::Int = defaultobs
67+
)
68+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
69+
if obsdim == 1
70+
@compat kernelmatrix!(K, κ, ColVecs(X), ColVecs(Y))
71+
else
72+
@compat kernelmatrix!(K, κ, RowVecs(X), RowVecs(Y))
73+
end
4674
end
4775

48-
## Apply kernel on two vectors ##
49-
function _kernel(
50-
κ::Kernel,
51-
x::AbstractVector,
52-
y::AbstractVector;
53-
obsdim::Int = defaultobs
76+
function kernelmatrix!(
77+
K::AbstractMatrix,
78+
κ::BaseKernel,
79+
X::AbstractVector,
80+
Y::AbstractVector
5481
)
55-
@assert length(x) == length(y) "x and y don't have the same dimension!"
56-
kappa(κ, evaluate(metric(κ),x,y))
82+
map!(K, κ, X, Y')
5783
end
5884

59-
_kernel::TransformedKernel, x::AbstractVector, y::AbstractVector; obsdim::Int = defaultobs) =
60-
_kernel(kernel(κ), apply.transform, x), apply.transform, y), obsdim = obsdim)
61-
6285
"""
6386
kernelmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
6487
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int = 2)
@@ -70,42 +93,52 @@ Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
7093
kernelmatrix
7194

7295
function kernelmatrix(
73-
κ::Kernel,
74-
X::AbstractVector{<:Real};
75-
obsdim::Int=defaultobs
76-
)
77-
kernelmatrix(κ,reshape(X,1,:),obsdim=2)
96+
κ::Kernel,
97+
X::AbstractVector{<:Real};
98+
obsdim::Int = defaultobs,
99+
)
100+
kernelmatrix(κ, reshape(X, 1, :), obsdim = 2)
78101
end
79102

80-
function kernelmatrix(
81-
κ::Kernel,
82-
X::AbstractMatrix;
83-
obsdim::Int = defaultobs
84-
)
85-
K = map(x->kappa(κ,x),pairwise(metric(κ),X,dims=obsdim))
103+
function kernelmatrix::Kernel, X::AbstractVector)
104+
kernelmatrix(κ, X, X) #TODO Can be optimized later
105+
end
106+
107+
function kernelmatrix::Kernel, X::AbstractVector, Y::AbstractVector)
108+
κ.(X, Y')
109+
end
110+
111+
112+
113+
function kernelmatrix::SimpleKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
114+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
115+
K = map(x -> kappa(κ, x), pairwise(metric(κ), X, dims = obsdim))
86116
end
87117

88-
kernelmatrix::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
89-
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
118+
function kernelmatrix::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
119+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
120+
if obsdim == 1
121+
kernelmatrix(κ, ColVecs(X))
122+
else
123+
kernelmatrix(κ, RowVecs(X))
124+
end
125+
end
90126

91127
function kernelmatrix(
92-
κ::Kernel,
93-
X::AbstractMatrix,
94-
Y::AbstractMatrix;
95-
obsdim=defaultobs
96-
)
97-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
98-
if !check_dims(X,Y,feature_dim(obsdim),obsdim)
128+
κ::SimpleKernel,
129+
X::AbstractMatrix,
130+
Y::AbstractMatrix;
131+
obsdim = defaultobs,
132+
)
133+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
134+
if !check_dims(X, Y, feature_dim(obsdim), obsdim)
99135
throw(DimensionMismatch("X $(size(X)) and Y $(size(Y)) do not have the same number of features on the dimension : $(feature_dim(obsdim))"))
100136
end
101-
_kernelmatrix(κ,X,Y,obsdim)
137+
_kernelmatrix(κ, X, Y, obsdim)
102138
end
103139

104140
@inline _kernelmatrix::SimpleKernel, X, Y, obsdim) =
105-
map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
106-
107-
kernelmatrix::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs) =
108-
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), apply.transform, Y, obsdim = obsdim), obsdim = obsdim)
141+
map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
109142

110143
"""
111144
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
@@ -115,16 +148,20 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
115148
`obsdim = 2` means the matrix `X` has size #dimension x #samples
116149
"""
117150
function kerneldiagmatrix(
118-
κ::Kernel,
119-
X::AbstractMatrix;
120-
obsdim::Int = defaultobs
121-
)
122-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
123-
if obsdim == 1
124-
@compat eachrow(X) .|> x-> κ(x, x) #[@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
125-
elseif obsdim == 2
126-
@compat eachcol(X) .|> x-> κ(x, x) #[@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
127-
end
151+
κ::Kernel,
152+
X::AbstractMatrix;
153+
obsdim::Int = defaultobs
154+
)
155+
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
156+
if obsdim == 1
157+
@compat kerneldiagmatrix(κ, ColVecs(X)) #[@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
158+
elseif obsdim == 2
159+
@compat kerneldiagmatrix(κ, RowVecs(X)) #[@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
160+
end
161+
end
162+
163+
function kerneldiagmatrix::Kernel, X::AbstractVector)
164+
κ.(X, X)
128165
end
129166

130167
"""
@@ -133,23 +170,31 @@ end
133170
In place version of [`kerneldiagmatrix`](@ref)
134171
"""
135172
function kerneldiagmatrix!(
136-
K::AbstractVector,
137-
κ::SimpleKernel,
138-
X::AbstractMatrix;
139-
obsdim::Int = defaultobs
140-
)
141-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
142-
if length(K) != size(X,obsdim)
143-
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
173+
K::AbstractVector,
174+
κ::Kernel,
175+
X::AbstractMatrix;
176+
obsdim::Int = defaultobs
177+
)
178+
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
179+
if length(K) != size(X,obsdim)
180+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
181+
end
182+
if obsdim == 1
183+
for i in eachindex(K)
184+
@inbounds @views K[i] = κ(X[i,:], X[i,:])
144185
end
145-
if obsdim == 1
146-
for i in eachindex(K)
147-
@inbounds @views K[i] = κ(X[i,:], X[i,:])
148-
end
149-
else
150-
for i in eachindex(K)
151-
@inbounds @views K[i] = κ(X[:,i], X[:,i])
152-
end
186+
else
187+
for i in eachindex(K)
188+
@inbounds @views K[i] = κ(X[:,i], X[:,i])
153189
end
154-
return K
190+
end
191+
return K
192+
end
193+
194+
function kerneldiagmatrix!(
195+
K::AbstractVector,
196+
κ::Kernel,
197+
X::AbstractVector
198+
)
199+
map!(κ, K, X, X)
155200
end

src/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ end
1919
# return T <: Real ? T : Float64
2020
# end
2121

22-
check_dims(K,X,Y,featdim,obsdim) = check_dims(X,Y,featdim,obsdim) && (size(K) == (size(X,obsdim),size(Y,obsdim)))
22+
check_dims(K, X, Y, featdim, obsdim) =
23+
check_dims(X, Y, featdim, obsdim) &&
24+
(size(K) == (size(X, obsdim), size(Y, obsdim)))
2325

24-
check_dims(X,Y,featdim,obsdim) = size(X,featdim) == size(Y,featdim)
26+
check_dims(X, Y, featdim, obsdim) = size(X, featdim) == size(Y, featdim)
2527

2628

2729
feature_dim(obsdim::Int) = obsdim == 1 ? 2 : 1

0 commit comments

Comments
 (0)