Skip to content

Commit 7b66a9f

Browse files
committed
Make accept AbstractMatrix
1 parent 86bbc56 commit 7b66a9f

File tree

6 files changed

+34
-18
lines changed

6 files changed

+34
-18
lines changed

src/affprop.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const _afp_default_tol = 1.0e-6
3333
const _afp_default_display = :none
3434

3535
"""
36-
affinityprop(S::DenseMatrix; [maxiter=200], [tol=1e-6], [damp=0.5],
36+
affinityprop(S::AbstractMatrix; [maxiter=200], [tol=1e-6], [damp=0.5],
3737
[display=:none]) -> AffinityPropResult
3838
3939
Perform affinity propagation clustering based on a similarity matrix `S`.
@@ -52,7 +52,7 @@ of the ``i``-th point as an *exemplar*.
5252
> Brendan J. Frey and Delbert Dueck. *Clustering by Passing Messages
5353
> Between Data Points.* Science, vol 315, pages 972-976, 2007.
5454
"""
55-
function affinityprop(S::DenseMatrix{T};
55+
function affinityprop(S::AbstractMatrix{T};
5656
maxiter::Integer=_afp_default_maxiter,
5757
tol::Real=_afp_default_tol,
5858
damp::Real=_afp_default_damp,
@@ -72,7 +72,7 @@ end
7272

7373
#### Implementation
7474

75-
function _affinityprop(S::DenseMatrix{T},
75+
function _affinityprop(S::AbstractMatrix{T},
7676
maxiter::Int,
7777
tol::Real,
7878
damp::T,
@@ -134,7 +134,7 @@ end
134134

135135

136136
# compute responsibilities
137-
function _afp_compute_r!(R::Matrix{T}, S::DenseMatrix{T}, A::Matrix{T}) where T
137+
function _afp_compute_r!(R::Matrix{T}, S::AbstractMatrix{T}, A::Matrix{T}) where T
138138
n = size(S, 1)
139139

140140
I1 = Vector{Int}(undef, n) # I1[i] is the column index of the maximum element in (A+S)[i,:]
@@ -245,7 +245,7 @@ function _afp_extract_exemplars(A::Matrix, R::Matrix)
245245
end
246246

247247
# get assignments
248-
function _afp_get_assignments(S::DenseMatrix, exemplars::Vector{Int})
248+
function _afp_get_assignments(S::AbstractMatrix, exemplars::Vector{Int})
249249
n = size(S, 1)
250250
k = length(exemplars)
251251
Se = S[:, exemplars]

src/dbscan.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
## main algorithm
4444

4545
"""
46-
dbscan(D::DenseMatrix, eps::Real, minpts::Int) -> DbscanResult
46+
dbscan(D::AbstractMatrix, eps::Real, minpts::Int) -> DbscanResult
4747
4848
Perform DBSCAN algorithm using the distance matrix `D`.
4949
@@ -54,7 +54,7 @@ The following options control which points would be considered
5454
- `minpts::Int`: the minimum number of neighboring points (including itself)
5555
to qualify a point as a density point.
5656
"""
57-
function dbscan(D::DenseMatrix{T}, eps::Real, minpts::Int) where T<:Real
57+
function dbscan(D::AbstractMatrix{T}, eps::Real, minpts::Int) where T<:Real
5858
# check arguments
5959
n = size(D, 1)
6060
size(D, 2) == n || throw(ArgumentError("D must be a square matrix ($(size(D)) given)."))
@@ -66,7 +66,7 @@ function dbscan(D::DenseMatrix{T}, eps::Real, minpts::Int) where T<:Real
6666
_dbscan(D, convert(T, eps), minpts, 1:n)
6767
end
6868

69-
function _dbscan(D::DenseMatrix{T}, eps::T, minpts::Int, visitseq::AbstractVector{Int}) where T<:Real
69+
function _dbscan(D::AbstractMatrix{T}, eps::T, minpts::Int, visitseq::AbstractVector{Int}) where T<:Real
7070
n = size(D, 1)
7171

7272
# prepare
@@ -96,7 +96,7 @@ end
9696

9797
## key steps
9898

99-
function _dbs_region_query(D::DenseMatrix{T}, p::Int, eps::T) where T<:Real
99+
function _dbs_region_query(D::AbstractMatrix{T}, p::Int, eps::T) where T<:Real
100100
n = size(D,1)
101101
nbs = Int[]
102102
dists = view(D,:,p)
@@ -108,7 +108,7 @@ function _dbs_region_query(D::DenseMatrix{T}, p::Int, eps::T) where T<:Real
108108
return nbs::Vector{Int}
109109
end
110110

111-
function _dbs_expand_cluster!(D::DenseMatrix{T}, # distance matrix
111+
function _dbs_expand_cluster!(D::AbstractMatrix{T}, # distance matrix
112112
k::Int, # the index of current cluster
113113
p::Int, # the index of seeding point
114114
nbs::Vector{Int}, # eps-neighborhood of p

src/kmedoids.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const _kmed_default_tol = 1.0e-8
3838
const _kmed_default_display = :none
3939

4040
"""
41-
kmedoids(dist::DenseMatrix, k::Integer; ...) -> KmedoidsResult
41+
kmedoids(dist::AbstractMatrix, k::Integer; ...) -> KmedoidsResult
4242
4343
Perform K-medoids clustering of ``n`` points into `k` clusters,
4444
given the `dist` matrix (``n×n``, `dist[i, j]` is the distance
@@ -59,7 +59,7 @@ The function implements a *K-means style* algorithm instead of *PAM*
5959
iterations, but was shown to produce worse (10-20% higher total costs) results
6060
(see e.g. [Schubert & Rousseeuw (2019)](@ref kmedoid_refs)).
6161
"""
62-
function kmedoids(dist::DenseMatrix{T}, k::Integer;
62+
function kmedoids(dist::AbstractMatrix{T}, k::Integer;
6363
init=_kmed_default_init,
6464
maxiter::Integer=_kmed_default_maxiter,
6565
tol::Real=_kmed_default_tol,
@@ -79,7 +79,7 @@ function kmedoids(dist::DenseMatrix{T}, k::Integer;
7979
end
8080

8181
"""
82-
kmedoids!(dist::DenseMatrix, medoids::Vector{Int};
82+
kmedoids!(dist::AbstractMatrix, medoids::Vector{Int};
8383
[kwargs...]) -> KmedoidsResult
8484
8585
Update the current cluster `medoids` using the `dist` matrix.
@@ -89,7 +89,7 @@ as `medoids` argument.
8989
9090
See [`kmedoids`](@ref) for the description of optional `kwargs`.
9191
"""
92-
function kmedoids!(dist::DenseMatrix{T}, medoids::Vector{Int};
92+
function kmedoids!(dist::AbstractMatrix{T}, medoids::Vector{Int};
9393
maxiter::Integer=_kmed_default_maxiter,
9494
tol::Real=_kmed_default_tol,
9595
display::Symbol=_kmed_default_display) where T<:Real
@@ -110,7 +110,7 @@ end
110110
#### core algorithm
111111

112112
function _kmedoids!(medoids::Vector{Int}, # initialized medoids
113-
dist::DenseMatrix{T}, # distance matrix
113+
dist::AbstractMatrix{T}, # distance matrix
114114
maxiter::Int, # maximum number of iterations
115115
tol::Real, # tolerable change of objective
116116
displevel::Int) where T<:Real # level of display
@@ -182,7 +182,7 @@ end
182182

183183

184184
# update assignments and related quantities
185-
function _kmed_update_assignments!(dist::DenseMatrix{T}, # in: (n, n)
185+
function _kmed_update_assignments!(dist::AbstractMatrix{T}, # in: (n, n)
186186
medoids::AbstractVector{Int}, # in: (k,)
187187
assignments::Vector{Int}, # out: (n,)
188188
groups::Vector{Vector{Int}}, # out: (k,)
@@ -233,7 +233,7 @@ end
233233
# find medoid for a given group
234234
#
235235
# TODO: faster way without creating temporary arrays
236-
function _find_medoid(dist::DenseMatrix, grp::Vector{Int})
236+
function _find_medoid(dist::AbstractMatrix, grp::Vector{Int})
237237
@assert !isempty(grp)
238238
p = argmin(sum(dist[grp, grp], dims=2))
239239
return grp[p]::Int

test/affprop.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ using Statistics
4646
@test R.counts[i] == count(==(i), R.assignments)
4747
end
4848

49+
@testset "Ensure works on nonbasic array type (SubArray)" begin
50+
RR = affinityprop(@view S[:,:]) # run on complete subarray
51+
@test RR.assignments == R.assignments
52+
end
53+
4954
#= compare with python result
5055
the reference assignments were computed using python sklearn:
5156
```julia

test/dbscan.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ for c = 1:k
3636
end
3737
@test all(R.counts .>= 180)
3838

39+
@testset "Ensure works on nonbasic array type (SubArray)" begin
40+
RR = dbscan(@view(D[:,:]), 1.0, 10) # run on complete subarray
41+
@test RR.assignments == R.assignments
42+
end
43+
3944
@testset "normal points" begin
4045
Random.seed!(0)
4146
p0 = randn(3, 1000)
@@ -65,5 +70,4 @@ end
6570
end
6671
end
6772

68-
6973
end

test/kmedoids.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ X = rand(d, n)
2727
dist = pairwise(SqEuclidean(), X, dims=2)
2828
@assert size(dist) == (n, n)
2929

30+
Random.seed!(34568) # reset seed again to known state
3031
R = kmedoids(dist, k)
3132
@test isa(R, KmedoidsResult)
3233
@test nclusters(R) == k
@@ -39,6 +40,12 @@ R = kmedoids(dist, k)
3940
@test isapprox(sum(R.costs), R.totalcost)
4041
@test R.converged
4142

43+
@testset "Ensure works on nonbasic array type (SubArray)" begin
44+
Random.seed!(34568) # restore seed as kmedoids is not determantistic
45+
RR = kmedoids(@view(dist[:,:]), k) # run on complete subarray
46+
@test RR.assignments == R.assignments
47+
end
48+
4249
# k=1 and k=n cases
4350
x = pairwise(SqEuclidean(), [1 2 3; .1 .2 .3; 4 5.6 7], dims=2)
4451
kmed1 = kmedoids(x, 1)

0 commit comments

Comments
 (0)