@@ -6,59 +6,82 @@ In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will b
6
6
"""
7
7
kernelmatrix!
8
8
9
-
10
9
function kernelmatrix! (
11
- K:: AbstractMatrix ,
12
- κ:: Kernel ,
13
- X:: AbstractMatrix ;
14
- obsdim:: Int = defaultobs
15
- )
16
- @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
17
- if ! check_dims (K,X,X, feature_dim (obsdim),obsdim)
18
- throw (DimensionMismatch (" Dimensions of the target array K $(size (K)) are not consistent with X $(size (X)) " ))
19
- end
20
- map! (x-> kappa (κ,x),K, pairwise (metric (κ),X, dims= obsdim))
10
+ K:: AbstractMatrix ,
11
+ κ:: SimpleKernel ,
12
+ X:: AbstractMatrix ;
13
+ obsdim:: Int = defaultobs,
14
+ )
15
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of ` kernelmatrix` ))"
16
+ if ! check_dims (K, X, X, feature_dim (obsdim), obsdim)
17
+ throw (DimensionMismatch (" Dimensions of the target array K $(size (K)) are not consistent with X $(size (X)) " ))
18
+ end
19
+ map! (x -> kappa (κ, x), K, pairwise (metric (κ), X, dims = obsdim))
21
20
end
22
21
23
- kernelmatrix! (K:: AbstractMatrix , κ:: TransformedKernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs) =
24
- kernelmatrix! (K, kernel (κ), apply (κ. transform, X, obsdim = obsdim), obsdim = obsdim)
22
+ function kernelmatrix! (
23
+ K:: AbstractMatrix ,
24
+ κ:: BaseKernel ,
25
+ X:: AbstractMatrix ;
26
+ obsdim:: Int = defaultobs
27
+ )
28
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
29
+ if obsdim == 1
30
+ @compat kernelmatrix! (K, κ, ColVecs (X))
31
+ else
32
+ @compat kernelmatrix! (K, κ, RowVecs (X))
33
+ end
34
+ end
25
35
26
36
function kernelmatrix! (
27
- K:: AbstractMatrix ,
28
- κ:: Kernel ,
29
- X:: AbstractMatrix ,
30
- Y:: AbstractMatrix ;
31
- obsdim:: Int = defaultobs
32
- )
33
- @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
34
- if ! check_dims (K,X,Y,feature_dim (obsdim),obsdim)
35
- throw (DimensionMismatch (" Dimensions $(size (K)) of the target array K are not consistent with X ($(size (X)) ) and Y ($(size (Y)) )" ))
36
- end
37
- map! (x-> kappa (κ,x),K,pairwise (metric (κ),X,Y,dims= obsdim))
37
+ K:: AbstractMatrix ,
38
+ κ:: BaseKernel ,
39
+ X:: AbstractVector
40
+ )
41
+ if ! check_dims (K, X, X, feature_dim (obsdim), obsdim)
42
+ throw (DimensionMismatch (" Dimensions of the target array K $(size (K)) are not consistent with X $(size (X)) " ))
43
+ end
44
+ map! (κ, K, X, X' )
38
45
end
39
46
40
- kernelmatrix! (K:: AbstractMatrix , κ:: TransformedKernel , X:: AbstractMatrix , Y:: AbstractMatrix ; obsdim:: Int = defaultobs) =
41
- kernelmatrix! (K, kernel (κ), apply (κ. transform, X, obsdim = obsdim), apply (κ. transform, Y, obsdim = obsdim), obsdim = obsdim)
47
+ function kernelmatrix! (
48
+ K:: AbstractMatrix ,
49
+ κ:: SimpleKernel ,
50
+ X:: AbstractMatrix ,
51
+ Y:: AbstractMatrix ;
52
+ obsdim:: Int = defaultobs,
53
+ )
54
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
55
+ if ! check_dims (K, X, Y, feature_dim (obsdim), obsdim)
56
+ throw (DimensionMismatch (" Dimensions $(size (K)) of the target array K are not consistent with X ($(size (X)) ) and Y ($(size (Y)) )" ))
57
+ end
58
+ map! (κ, K, pairwise (metric (κ), X, Y, dims = obsdim))
59
+ end
42
60
43
- # # Apply kernel on two reals ##
44
- function _kernel (κ:: Kernel , x:: Real , y:: Real )
45
- _kernel (κ, [x], [y])
61
+ function kernelmatrix! (
62
+ K:: AbstractMatrix ,
63
+ κ:: BaseKernel ,
64
+ X:: AbstractMatrix ,
65
+ Y:: AbstractMatrix ;
66
+ obsdim:: Int = defaultobs
67
+ )
68
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
69
+ if obsdim == 1
70
+ @compat kernelmatrix! (K, κ, ColVecs (X), ColVecs (Y))
71
+ else
72
+ @compat kernelmatrix! (K, κ, RowVecs (X), RowVecs (Y))
73
+ end
46
74
end
47
75
48
- # # Apply kernel on two vectors ##
49
- function _kernel (
50
- κ:: Kernel ,
51
- x:: AbstractVector ,
52
- y:: AbstractVector ;
53
- obsdim:: Int = defaultobs
76
+ function kernelmatrix! (
77
+ K:: AbstractMatrix ,
78
+ κ:: BaseKernel ,
79
+ X:: AbstractVector ,
80
+ Y:: AbstractVector
54
81
)
55
- @assert length (x) == length (y) " x and y don't have the same dimension!"
56
- kappa (κ, evaluate (metric (κ),x,y))
82
+ map! (K, κ, X, Y' )
57
83
end
58
84
59
- _kernel (κ:: TransformedKernel , x:: AbstractVector , y:: AbstractVector ; obsdim:: Int = defaultobs) =
60
- _kernel (kernel (κ), apply (κ. transform, x), apply (κ. transform, y), obsdim = obsdim)
61
-
62
85
"""
63
86
kernelmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
64
87
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int = 2)
@@ -70,42 +93,52 @@ Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
70
93
kernelmatrix
71
94
72
95
function kernelmatrix (
73
- κ:: Kernel ,
74
- X:: AbstractVector{<:Real} ;
75
- obsdim:: Int = defaultobs
76
- )
77
- kernelmatrix (κ,reshape (X,1 , :),obsdim= 2 )
96
+ κ:: Kernel ,
97
+ X:: AbstractVector{<:Real} ;
98
+ obsdim:: Int = defaultobs,
99
+ )
100
+ kernelmatrix (κ, reshape (X, 1 , :), obsdim = 2 )
78
101
end
79
102
80
- function kernelmatrix (
81
- κ:: Kernel ,
82
- X:: AbstractMatrix ;
83
- obsdim:: Int = defaultobs
84
- )
85
- K = map (x-> kappa (κ,x),pairwise (metric (κ),X,dims= obsdim))
103
+ function kernelmatrix (κ:: Kernel , X:: AbstractVector )
104
+ kernelmatrix (κ, X, X) # TODO Can be optimized later
105
+ end
106
+
107
+ function kernelmatrix (κ:: Kernel , X:: AbstractVector , Y:: AbstractVector )
108
+ κ .(X, Y' )
109
+ end
110
+
111
+
112
+
113
+ function kernelmatrix (κ:: SimpleKernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs)
114
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
115
+ K = map (x -> kappa (κ, x), pairwise (metric (κ), X, dims = obsdim))
86
116
end
87
117
88
- kernelmatrix (κ:: TransformedKernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs) =
89
- kernelmatrix (kernel (κ), apply (κ. transform, X, obsdim = obsdim), obsdim = obsdim)
118
+ function kernelmatrix (κ:: Kernel , X:: AbstractMatrix ; obsdim:: Int = defaultobs)
119
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
120
+ if obsdim == 1
121
+ kernelmatrix (κ, ColVecs (X))
122
+ else
123
+ kernelmatrix (κ, RowVecs (X))
124
+ end
125
+ end
90
126
91
127
function kernelmatrix (
92
- κ:: Kernel ,
93
- X:: AbstractMatrix ,
94
- Y:: AbstractMatrix ;
95
- obsdim= defaultobs
96
- )
97
- @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
98
- if ! check_dims (X,Y, feature_dim (obsdim),obsdim)
128
+ κ:: SimpleKernel ,
129
+ X:: AbstractMatrix ,
130
+ Y:: AbstractMatrix ;
131
+ obsdim = defaultobs,
132
+ )
133
+ @assert obsdim ∈ [1 , 2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
134
+ if ! check_dims (X, Y, feature_dim (obsdim), obsdim)
99
135
throw (DimensionMismatch (" X $(size (X)) and Y $(size (Y)) do not have the same number of features on the dimension : $(feature_dim (obsdim)) " ))
100
136
end
101
- _kernelmatrix (κ,X,Y, obsdim)
137
+ _kernelmatrix (κ, X, Y, obsdim)
102
138
end
103
139
104
140
@inline _kernelmatrix (κ:: SimpleKernel , X, Y, obsdim) =
105
- map (x -> kappa (κ, x), pairwise (metric (κ), X, Y, dims = obsdim))
106
-
107
- kernelmatrix (κ:: TransformedKernel , X:: AbstractMatrix , Y:: AbstractMatrix ; obsdim:: Int = defaultobs) =
108
- kernelmatrix (kernel (κ), apply (κ. transform, X, obsdim = obsdim), apply (κ. transform, Y, obsdim = obsdim), obsdim = obsdim)
141
+ map (x -> kappa (κ, x), pairwise (metric (κ), X, Y, dims = obsdim))
109
142
110
143
"""
111
144
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
@@ -115,16 +148,20 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
115
148
`obsdim = 2` means the matrix `X` has size #dimension x #samples
116
149
"""
117
150
function kerneldiagmatrix (
118
- κ:: Kernel ,
119
- X:: AbstractMatrix ;
120
- obsdim:: Int = defaultobs
121
- )
122
- @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
123
- if obsdim == 1
124
- @compat eachrow (X) .| > x-> κ (x, x) # [@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
125
- elseif obsdim == 2
126
- @compat eachcol (X) .| > x-> κ (x, x) # [@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
127
- end
151
+ κ:: Kernel ,
152
+ X:: AbstractMatrix ;
153
+ obsdim:: Int = defaultobs
154
+ )
155
+ @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
156
+ if obsdim == 1
157
+ @compat kerneldiagmatrix (κ, ColVecs (X)) # [@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
158
+ elseif obsdim == 2
159
+ @compat kerneldiagmatrix (κ, RowVecs (X)) # [@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
160
+ end
161
+ end
162
+
163
+ function kerneldiagmatrix (κ:: Kernel , X:: AbstractVector )
164
+ κ .(X, X)
128
165
end
129
166
130
167
"""
@@ -133,23 +170,31 @@ end
133
170
In place version of [`kerneldiagmatrix`](@ref)
134
171
"""
135
172
function kerneldiagmatrix! (
136
- K:: AbstractVector ,
137
- κ:: SimpleKernel ,
138
- X:: AbstractMatrix ;
139
- obsdim:: Int = defaultobs
140
- )
141
- @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
142
- if length (K) != size (X,obsdim)
143
- throw (DimensionMismatch (" Dimensions of the target array K $(size (K)) are not consistent with X $(size (X)) " ))
173
+ K:: AbstractVector ,
174
+ κ:: Kernel ,
175
+ X:: AbstractMatrix ;
176
+ obsdim:: Int = defaultobs
177
+ )
178
+ @assert obsdim ∈ [1 ,2 ] " obsdim should be 1 or 2 (see docs of kernelmatrix))"
179
+ if length (K) != size (X,obsdim)
180
+ throw (DimensionMismatch (" Dimensions of the target array K $(size (K)) are not consistent with X $(size (X)) " ))
181
+ end
182
+ if obsdim == 1
183
+ for i in eachindex (K)
184
+ @inbounds @views K[i] = κ (X[i,:], X[i,:])
144
185
end
145
- if obsdim == 1
146
- for i in eachindex (K)
147
- @inbounds @views K[i] = κ (X[i,:], X[i,:])
148
- end
149
- else
150
- for i in eachindex (K)
151
- @inbounds @views K[i] = κ (X[:,i], X[:,i])
152
- end
186
+ else
187
+ for i in eachindex (K)
188
+ @inbounds @views K[i] = κ (X[:,i], X[:,i])
153
189
end
154
- return K
190
+ end
191
+ return K
192
+ end
193
+
194
+ function kerneldiagmatrix! (
195
+ K:: AbstractVector ,
196
+ κ:: Kernel ,
197
+ X:: AbstractVector
198
+ )
199
+ map! (κ, K, X, X)
155
200
end
0 commit comments