Skip to content

Commit 0f04a70

Browse files
authored
Merge pull request #30 from IsakFalk/isakfalk/nystrom
Nystrom Approximation
2 parents 4a16377 + bdbe6e0 commit 0f04a70

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1010
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
11+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1112
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1213
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1314

src/KernelFunctions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@ export KernelSum, KernelProduct
1414

1515
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1616

17+
export NystromFact, nystrom
18+
1719
using Compat
1820
using Distances, LinearAlgebra
1921
using SpecialFunctions: logabsgamma, besselk
2022
using ZygoteRules: @adjoint
2123
using StatsFuns: logtwo
24+
using StatsBase
2225
using PDMats: PDMat
2326

2427
const defaultobs = 2
@@ -41,6 +44,7 @@ include("matrix/kernelmatrix.jl")
4144
include("matrix/kernelpdmat.jl")
4245
include("kernels/kernelsum.jl")
4346
include("kernels/kernelproduct.jl")
47+
include("approximations/nystrom.jl")
4448

4549
include("generic.jl")
4650

src/approximations/nystrom.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Following the algorithm by William and Seeger, 2001
2+
# Cs is equivalent to X_mm and C to X_mn
3+
4+
function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
5+
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
6+
n = size(X, obsdim)
7+
m = ceil(Int, n*r)
8+
S = StatsBase.sample(1:n, m; replace=false, ordered=true)
9+
return S
10+
end
11+
12+
function nystrom_sample(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs)
13+
obsdim [1, 2] || throw(ArgumentError("`obsdim` should be 1 or 2 (see docs of kernelmatrix))"))
14+
Xₘ = obsdim == 1 ? X[S, :] : X[:, S]
15+
C = k(Xₘ, X; obsdim=obsdim)
16+
Cs = C[:, S]
17+
return (C, Cs)
18+
end
19+
20+
function nystrom_pinv!(Cs::Matrix{T}, tol::T=eps(T)*size(Cs,1)) where {T<:Real}
21+
# Compute eigendecomposition of sampled component of K
22+
QΛQᵀ = LinearAlgebra.eigen!(LinearAlgebra.Symmetric(Cs))
23+
24+
# Solve for D = Λ^(-1/2) (pseudo inverse - use tolerance from before factorization)
25+
D = QΛQᵀ.values
26+
λ_tol = maximum(D)*tol
27+
28+
for i in eachindex(D)
29+
@inbounds D[i] = abs(D[i]) <= λ_tol ? zero(T) : one(T)/sqrt(D[i])
30+
end
31+
32+
# Scale eigenvectors by D
33+
Q = QΛQᵀ.vectors
34+
QD = LinearAlgebra.rmul!(Q, LinearAlgebra.Diagonal(D)) # Scales column i of Q by D[i]
35+
36+
# W := (QD)(QD)ᵀ = (QΛQᵀ)^(-1) (pseudo inverse)
37+
W = QD*QD'
38+
39+
# Symmetrize W
40+
return LinearAlgebra.copytri!(W, 'U')
41+
end
42+
43+
@doc raw"""
44+
NystromFact
45+
46+
Type for storing a Nystrom factorization. The factorization contains two fields: `W` and
47+
`C`, two matrices satisfying:
48+
```math
49+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
50+
```
51+
"""
52+
struct NystromFact{T<:Real}
53+
W::Matrix{T}
54+
C::Matrix{T}
55+
end
56+
57+
function NystromFact(W::Matrix{<:Real}, C::Matrix{<:Real})
58+
T = Base.promote_eltypeof(W, C)
59+
return NystromFact(convert(Matrix{T}, W), convert(Matrix{T}, C))
60+
end
61+
62+
@doc raw"""
63+
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
64+
65+
Computes a factorization of Nystrom approximation of the square kernel matrix of data
66+
matrix `X` with respect to kernel `k`. Returns a `NystromFact` struct which stores a
67+
Nystrom factorization satisfying:
68+
```math
69+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
70+
```
71+
"""
72+
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
73+
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
74+
W = nystrom_pinv!(Cs)
75+
return NystromFact(W, C)
76+
end
77+
78+
@doc raw"""
79+
nystrom(k::Kernel, X::Matrix, r::Real; obsdim::Int=defaultobs)
80+
81+
Computes a factorization of Nystrom approximation of the square kernel matrix of data
82+
matrix `X` with respect to kernel `k` using a sample ratio of `r`.
83+
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
84+
```math
85+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
86+
```
87+
"""
88+
function nystrom(k::Kernel, X::AbstractMatrix, r::Real; obsdim::Int=defaultobs)
89+
S = sampleindex(X, r; obsdim=obsdim)
90+
return nystrom(k, X, S; obsdim=obsdim)
91+
end
92+
93+
"""
94+
nystrom(CᵀWC::NystromFact)
95+
96+
Compute the approximate kernel matrix based on the Nystrom factorization.
97+
"""
98+
function kernelmatrix(CᵀWC::NystromFact{<:Real})
99+
W = CᵀWC.W
100+
C = CᵀWC.C
101+
return C'*W*C
102+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Random
55

66
@testset "KernelFunctions" begin
77
include("test_kernelmatrix.jl")
8+
include("test_approximations.jl")
89
include("test_constructors.jl")
910
# include("test_AD.jl")
1011
include("test_transform.jl")

test/test_approximations.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Distances, LinearAlgebra
2+
using Test
3+
using KernelFunctions
4+
5+
dims = [10,5]
6+
X = rand(dims...)
7+
k = SqExponentialKernel()
8+
@testset "Kernel Matrix Approximations" begin
9+
@testset "Nystrom" begin
10+
for obsdim in [1, 2]
11+
@test kernelmatrix(k, X; obsdim=obsdim) kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
12+
@test kernelmatrix(k, X; obsdim=obsdim) kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
13+
end
14+
end
15+
end

0 commit comments

Comments
 (0)