Skip to content

Commit 0eda1d0

Browse files
authored
Merge pull request #193 from oxinabox/ox/chillaxonthematrix
Make clustering methods accept AbstractMatrix
2 parents 152e06f + 6560ff5 commit 0eda1d0

File tree

7 files changed

+67
-19
lines changed

7 files changed

+67
-19
lines changed

src/affprop.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const _afp_default_tol = 1.0e-6
3333
const _afp_default_display = :none
3434

3535
"""
36-
affinityprop(S::DenseMatrix; [maxiter=200], [tol=1e-6], [damp=0.5],
36+
affinityprop(S::AbstractMatrix; [maxiter=200], [tol=1e-6], [damp=0.5],
3737
[display=:none]) -> AffinityPropResult
3838
3939
Perform affinity propagation clustering based on a similarity matrix `S`.
@@ -52,7 +52,7 @@ of the ``i``-th point as an *exemplar*.
5252
> Brendan J. Frey and Delbert Dueck. *Clustering by Passing Messages
5353
> Between Data Points.* Science, vol 315, pages 972-976, 2007.
5454
"""
55-
function affinityprop(S::DenseMatrix{T};
55+
function affinityprop(S::AbstractMatrix{T};
5656
maxiter::Integer=_afp_default_maxiter,
5757
tol::Real=_afp_default_tol,
5858
damp::Real=_afp_default_damp,
@@ -72,7 +72,7 @@ end
7272

7373
#### Implementation
7474

75-
function _affinityprop(S::DenseMatrix{T},
75+
function _affinityprop(S::AbstractMatrix{T},
7676
maxiter::Int,
7777
tol::Real,
7878
damp::T,
@@ -134,7 +134,7 @@ end
134134

135135

136136
# compute responsibilities
137-
function _afp_compute_r!(R::Matrix{T}, S::DenseMatrix{T}, A::Matrix{T}) where T
137+
function _afp_compute_r!(R::Matrix{T}, S::AbstractMatrix{T}, A::Matrix{T}) where T
138138
n = size(S, 1)
139139

140140
I1 = Vector{Int}(undef, n) # I1[i] is the column index of the maximum element in (A+S)[i,:]
@@ -245,7 +245,7 @@ function _afp_extract_exemplars(A::Matrix, R::Matrix)
245245
end
246246

247247
# get assignments
248-
function _afp_get_assignments(S::DenseMatrix, exemplars::Vector{Int})
248+
function _afp_get_assignments(S::AbstractMatrix, exemplars::Vector{Int})
249249
n = size(S, 1)
250250
k = length(exemplars)
251251
Se = S[:, exemplars]

src/dbscan.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
## main algorithm
4444

4545
"""
46-
dbscan(D::DenseMatrix, eps::Real, minpts::Int) -> DbscanResult
46+
dbscan(D::AbstractMatrix, eps::Real, minpts::Int) -> DbscanResult
4747
4848
Perform DBSCAN algorithm using the distance matrix `D`.
4949
@@ -54,7 +54,7 @@ The following options control which points would be considered
5454
- `minpts::Int`: the minimum number of neighboring points (including itself)
5555
to qualify a point as a density point.
5656
"""
57-
function dbscan(D::DenseMatrix{T}, eps::Real, minpts::Int) where T<:Real
57+
function dbscan(D::AbstractMatrix{T}, eps::Real, minpts::Int) where T<:Real
5858
# check arguments
5959
n = size(D, 1)
6060
size(D, 2) == n || throw(ArgumentError("D must be a square matrix ($(size(D)) given)."))
@@ -66,7 +66,7 @@ function dbscan(D::DenseMatrix{T}, eps::Real, minpts::Int) where T<:Real
6666
_dbscan(D, convert(T, eps), minpts, 1:n)
6767
end
6868

69-
function _dbscan(D::DenseMatrix{T}, eps::T, minpts::Int, visitseq::AbstractVector{Int}) where T<:Real
69+
function _dbscan(D::AbstractMatrix{T}, eps::T, minpts::Int, visitseq::AbstractVector{Int}) where T<:Real
7070
n = size(D, 1)
7171

7272
# prepare
@@ -96,7 +96,7 @@ end
9696

9797
## key steps
9898

99-
function _dbs_region_query(D::DenseMatrix{T}, p::Int, eps::T) where T<:Real
99+
function _dbs_region_query(D::AbstractMatrix{T}, p::Int, eps::T) where T<:Real
100100
n = size(D,1)
101101
nbs = Int[]
102102
dists = view(D,:,p)
@@ -108,7 +108,7 @@ function _dbs_region_query(D::DenseMatrix{T}, p::Int, eps::T) where T<:Real
108108
return nbs::Vector{Int}
109109
end
110110

111-
function _dbs_expand_cluster!(D::DenseMatrix{T}, # distance matrix
111+
function _dbs_expand_cluster!(D::AbstractMatrix{T}, # distance matrix
112112
k::Int, # the index of current cluster
113113
p::Int, # the index of seeding point
114114
nbs::Vector{Int}, # eps-neighborhood of p

src/kmedoids.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const _kmed_default_tol = 1.0e-8
3838
const _kmed_default_display = :none
3939

4040
"""
41-
kmedoids(dist::DenseMatrix, k::Integer; ...) -> KmedoidsResult
41+
kmedoids(dist::AbstractMatrix, k::Integer; ...) -> KmedoidsResult
4242
4343
Perform K-medoids clustering of ``n`` points into `k` clusters,
4444
given the `dist` matrix (``n×n``, `dist[i, j]` is the distance
@@ -59,7 +59,7 @@ The function implements a *K-means style* algorithm instead of *PAM*
5959
iterations, but was shown to produce worse (10-20% higher total costs) results
6060
(see e.g. [Schubert & Rousseeuw (2019)](@ref kmedoid_refs)).
6161
"""
62-
function kmedoids(dist::DenseMatrix{T}, k::Integer;
62+
function kmedoids(dist::AbstractMatrix{T}, k::Integer;
6363
init=_kmed_default_init,
6464
maxiter::Integer=_kmed_default_maxiter,
6565
tol::Real=_kmed_default_tol,
@@ -79,7 +79,7 @@ function kmedoids(dist::DenseMatrix{T}, k::Integer;
7979
end
8080

8181
"""
82-
kmedoids!(dist::DenseMatrix, medoids::Vector{Int};
82+
kmedoids!(dist::AbstractMatrix, medoids::Vector{Int};
8383
[kwargs...]) -> KmedoidsResult
8484
8585
Update the current cluster `medoids` using the `dist` matrix.
@@ -89,7 +89,7 @@ as `medoids` argument.
8989
9090
See [`kmedoids`](@ref) for the description of optional `kwargs`.
9191
"""
92-
function kmedoids!(dist::DenseMatrix{T}, medoids::Vector{Int};
92+
function kmedoids!(dist::AbstractMatrix{T}, medoids::Vector{Int};
9393
maxiter::Integer=_kmed_default_maxiter,
9494
tol::Real=_kmed_default_tol,
9595
display::Symbol=_kmed_default_display) where T<:Real
@@ -110,7 +110,7 @@ end
110110
#### core algorithm
111111

112112
function _kmedoids!(medoids::Vector{Int}, # initialized medoids
113-
dist::DenseMatrix{T}, # distance matrix
113+
dist::AbstractMatrix{T}, # distance matrix
114114
maxiter::Int, # maximum number of iterations
115115
tol::Real, # tolerable change of objective
116116
displevel::Int) where T<:Real # level of display
@@ -182,7 +182,7 @@ end
182182

183183

184184
# update assignments and related quantities
185-
function _kmed_update_assignments!(dist::DenseMatrix{T}, # in: (n, n)
185+
function _kmed_update_assignments!(dist::AbstractMatrix{T}, # in: (n, n)
186186
medoids::AbstractVector{Int}, # in: (k,)
187187
assignments::Vector{Int}, # out: (n,)
188188
groups::Vector{Vector{Int}}, # out: (k,)
@@ -233,7 +233,7 @@ end
233233
# find medoid for a given group
234234
#
235235
# TODO: faster way without creating temporary arrays
236-
function _find_medoid(dist::DenseMatrix, grp::Vector{Int})
236+
function _find_medoid(dist::AbstractMatrix, grp::Vector{Int})
237237
@assert !isempty(grp)
238238
p = argmin(sum(dist[grp, grp], dims=2))
239239
return grp[p]::Int

test/affprop.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Clustering
66
using LinearAlgebra
77
using Random
88
using Statistics
9+
include("test_helpers.jl")
910

1011
@testset "affinityprop() (affinity propagation)" begin
1112
@testset "Argument checks" begin
@@ -46,6 +47,13 @@ using Statistics
4647
@test R.counts[i] == count(==(i), R.assignments)
4748
end
4849

50+
@testset "Support for arrays other than Matrix{T}" begin
51+
@testset "$(typeof(M))" for M in equivalent_matrices(S)
52+
R2 = affinityprop(M)
53+
@test R2.assignments == R.assignments
54+
end
55+
end
56+
4957
#= compare with python result
5058
the reference assignments were computed using python sklearn:
5159
```julia

test/dbscan.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
using Clustering
33
using Distances
4+
include("test_helpers.jl")
45

56
@testset "dbscan() (DBSCAN clustering)" begin
67

@@ -36,6 +37,13 @@ for c = 1:k
3637
end
3738
@test all(R.counts .>= 180)
3839

40+
@testset "Support for arrays other than Matrix{T}" begin
41+
@testset "$(typeof(M))" for M in equivalent_matrices(D)
42+
R2 = dbscan(M, 1.0, 10) # run on complete subarray
43+
@test R2.assignments == R.assignments
44+
end
45+
end
46+
3947
@testset "normal points" begin
4048
Random.seed!(0)
4149
p0 = randn(3, 1000)
@@ -65,5 +73,4 @@ end
6573
end
6674
end
6775

68-
6976
end

test/kmedoids.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test
2-
32
using Distances
43
using Clustering
4+
include("test_helpers.jl")
55

66
@testset "kmedoids() (k-medoids)" begin
77

@@ -27,6 +27,7 @@ X = rand(d, n)
2727
dist = pairwise(SqEuclidean(), X, dims=2)
2828
@assert size(dist) == (n, n)
2929

30+
Random.seed!(34568) # reset seed again to known state
3031
R = kmedoids(dist, k)
3132
@test isa(R, KmedoidsResult)
3233
@test nclusters(R) == k
@@ -39,6 +40,14 @@ R = kmedoids(dist, k)
3940
@test isapprox(sum(R.costs), R.totalcost)
4041
@test R.converged
4142

43+
@testset "Support for arrays other than Matrix{T}" begin
44+
@testset "$(typeof(M))" for M in equivalent_matrices(dist)
45+
Random.seed!(34568) # restore seed as kmedoids is not determantistic
46+
R2 = kmedoids(M, k)
47+
@test R2.assignments == R.assignments
48+
end
49+
end
50+
4251
# k=1 and k=n cases
4352
x = pairwise(SqEuclidean(), [1 2 3; .1 .2 .3; 4 5.6 7], dims=2)
4453
kmed1 = kmedoids(x, 1)

test/test_helpers.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using LinearAlgebra
2+
using SparseArrays
3+
4+
"""
5+
equivalent_matrices(x::AbstractMatrix)
6+
7+
Returns a collection of matrixes that are equal to the input `x`, but of a different type.
8+
Useful for testing if things still work on different types of matrix.
9+
"""
10+
function equivalent_matrices(x::AbstractMatrix)
11+
mats = [
12+
Base.PermutedDimsArray(x, (1,2)), # identity permutation
13+
view(x, :, :),
14+
view(x, collect.(axes(x))...), # breaks `strides`
15+
sparse(x),
16+
]
17+
if issymmetric(x)
18+
append!(mats, [
19+
Symmetric(x),
20+
Transpose(x),
21+
])
22+
end
23+
return mats
24+
end

0 commit comments

Comments
 (0)