Skip to content

Commit eeaf1c6

Browse files
committed
Implement Nystrom approximation in nystrom.jl
1 parent 58d42f5 commit eeaf1c6

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1010
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
11+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1112
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1213
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1314

src/KernelFunctions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ export KernelSum, KernelProduct
1414

1515
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1616

17+
export NystromFact, nystrom
18+
1719
using Compat
1820
using Distances, LinearAlgebra
1921
using SpecialFunctions: logabsgamma, besselk
@@ -41,6 +43,7 @@ include("matrix/kernelmatrix.jl")
4143
include("matrix/kernelpdmat.jl")
4244
include("kernels/kernelsum.jl")
4345
include("kernels/kernelproduct.jl")
46+
include("approximations/nystrom.jl")
4447

4548
include("generic.jl")
4649

src/approximations/nystrom.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
using StatsBase
2+
using LinearAlgebra
3+
# Following the algorithm by William and Seeger, 2001
4+
# Cs is equivalent to X_mm and C to X_mn
5+
6+
function sampleindex(X::AbstractMatrix, r::AbstractFloat; obsdim::Integer=defaultobs)
7+
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
8+
n = size(X, obsdim)
9+
m = ceil(Int, n*r)
10+
S = StatsBase.sample(collect(1:n), m; replace=false, ordered=true)
11+
return S
12+
end
13+
14+
function nystrom_sample(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs)
15+
obsdim [1, 2] || throw(ArgumentError("`obsdim` should be 1 or 2 (see docs of kernelmatrix))"))
16+
Xₘ = obsdim == 1 ? getindex(X, S, :) : getindex(X, :, S)
17+
C = k(Xₘ, X; obsdim=obsdim)
18+
Cs = getindex(C, :, S)
19+
return (C, Cs)
20+
end
21+
22+
function nystrom_pinv!(Cs::Matrix{T}, tol::T=eps(T)*size(Cs,1)) where {T<:AbstractFloat}
23+
# Compute eigendecomposition of sampled component of K
24+
QΛQᵀ = LinearAlgebra.eigen!(LinearAlgebra.Symmetric(Cs))
25+
26+
# Solve for D = Λ^(-1/2) (pseudo inverse - use tolerance from before factorization)
27+
D = QΛQᵀ.values
28+
λ_tol = maximum(D)*tol
29+
30+
for i in eachindex(D)
31+
@inbounds D[i] = abs(D[i]) <= λ_tol ? zero(T) : one(T)/sqrt(D[i])
32+
end
33+
34+
# Scale eigenvectors by D
35+
Q = QΛQᵀ.vectors
36+
QD = LinearAlgebra.rmul!(Q, LinearAlgebra.Diagonal(D)) # Scales column i of Q by D[i]
37+
38+
# W := (QD)(QD)ᵀ = (QΛQᵀ)^(-1) (pseudo inverse)
39+
W = QD*QD'
40+
41+
# Symmetrize W
42+
return LinearAlgebra.copytri!(W, 'U')
43+
end
44+
45+
"""
46+
NystromFact
47+
48+
Type for storing a Nystrom factorization. The factorization contains two fields: `W` and
49+
`C`, two matrices satisfying:
50+
```math
51+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
52+
```
53+
"""
54+
struct NystromFact{T<:AbstractFloat}
55+
W::Matrix{T}
56+
C::Matrix{T}
57+
end
58+
59+
@doc raw"""
60+
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
61+
62+
Computes a factorization of Nystrom approximation of the square kernel matrix of data
63+
matrix `X` with respect to kernel `k`. Returns a `NystromFact` struct which stores a
64+
Nystrom factorization satisfying:
65+
```math
66+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
67+
```
68+
"""
69+
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
70+
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
71+
W = nystrom_pinv!(Cs)
72+
T = typeof(first(W))
73+
return NystromFact{T}(W, C)
74+
end
75+
76+
@doc raw"""
77+
nystrom(k::Kernel, X::Matrix, r::AbstractFloat; obsdim::Int=defaultobs)
78+
79+
Computes a factorization of Nystrom approximation of the square kernel matrix of data
80+
matrix `X` with respect to kernel `k` using a sample ratio of `r`.
81+
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
82+
```math
83+
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
84+
```
85+
"""
86+
function nystrom(k::Kernel, X::AbstractMatrix, r::AbstractFloat; obsdim::Int=defaultobs)
87+
S = sampleindex(X, r; obsdim=obsdim)
88+
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
89+
W = nystrom_pinv!(Cs)
90+
T = typeof(first(W))
91+
return NystromFact{T}(W, C)
92+
end
93+
94+
"""
95+
nystrom(CᵀWC::NystromFact)
96+
97+
Compute the approximate kernel matrix based on the Nystrom factorization.
98+
"""
99+
function kernelmatrix(CᵀWC::NystromFact{<:AbstractFloat})
100+
W = CᵀWC.W
101+
C = CᵀWC.C
102+
return C'*W*C
103+
end

0 commit comments

Comments
 (0)