Skip to content

Commit 0d0cc97

Browse files
committed
Clean up and refactor
1 parent 5752b64 commit 0d0cc97

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransf
1717
export NystromFact, nystrom
1818

1919
using Compat
20+
using LinearAlgebra
2021
using Distances, LinearAlgebra
2122
using SpecialFunctions: logabsgamma, besselk
2223
using ZygoteRules: @adjoint
2324
using StatsFuns: logtwo
25+
using StatsBase
2426
using PDMats: PDMat
2527

2628
const defaultobs = 2

src/approximations/nystrom.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
using StatsBase
2-
using LinearAlgebra
31
# Following the algorithm by William and Seeger, 2001
42
# Cs is equivalent to X_mm and C to X_mn
53

6-
function sampleindex(X::AbstractMatrix, r::AbstractFloat; obsdim::Integer=defaultobs)
4+
function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
75
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
86
n = size(X, obsdim)
97
m = ceil(Int, n*r)
@@ -51,11 +49,16 @@ Type for storing a Nystrom factorization. The factorization contains two fields:
5149
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
5250
```
5351
"""
54-
struct NystromFact{T<:AbstractFloat}
52+
struct NystromFact{T<:Real}
5553
W::Matrix{T}
5654
C::Matrix{T}
5755
end
5856

57+
function NystromFact(W::Matrix{<:Real}, C::Matrix{<:Real})
58+
T = Base.promote_eltypeof(W, C)
59+
return NystromFact(convert(Matrix{T}, W), convert(Matrix{T}, C))
60+
end
61+
5962
@doc raw"""
6063
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
6164
@@ -69,8 +72,7 @@ Nystrom factorization satisfying:
6972
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
7073
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
7174
W = nystrom_pinv!(Cs)
72-
T = typeof(first(W))
73-
return NystromFact{T}(W, C)
75+
return NystromFact(W, C)
7476
end
7577

7678
@doc raw"""
@@ -87,8 +89,7 @@ function nystrom(k::Kernel, X::AbstractMatrix, r::AbstractFloat; obsdim::Int=def
8789
S = sampleindex(X, r; obsdim=obsdim)
8890
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
8991
W = nystrom_pinv!(Cs)
90-
T = typeof(first(W))
91-
return NystromFact{T}(W, C)
92+
return NystromFact(W, C)
9293
end
9394

9495
"""

0 commit comments

Comments
 (0)