Skip to content

Commit 01ea380

Browse files
authored
Add rng parameter to appropriate methods (#228)
* add `Random` as dependency * added `rng` parameter to `fuzzy_cmeans` * added `rng` parameter to `kmeans` * switch to StableRNGs in `hclust` tests * use old kwargs syntax * make k-means result comparison more robust * generated new test data for `affprop` test
1 parent 0b6f7b8 commit 01ea380

File tree

9 files changed

+174
-137
lines changed

9 files changed

+174
-137
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
13+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314

1415
[compat]
1516
Distances = "0.8, 0.9, 0.10"

src/Clustering.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module Clustering
77
using LinearAlgebra
88
using SparseArrays
99
using Statistics
10+
using Random
1011

1112
import Base: show
1213
import StatsBase: IntegerVector, RealVector, RealMatrix, counts

src/fuzzycmeans.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The output of [`fuzzy_cmeans`](@ref) function.
1919
struct FuzzyCMeansResult{T<:AbstractFloat}
2020
centers::Matrix{T} # cluster centers (d x C)
2121
weights::Matrix{Float64} # assigned weights (n x C)
22-
iterations::Int # number of elasped iterations
22+
iterations::Int # number of elapsed iterations
2323
converged::Bool # whether the procedure converged
2424
end
2525

@@ -88,14 +88,15 @@ function fuzzy_cmeans(
8888
maxiter::Int = _fcmeans_default_maxiter,
8989
tol::Real = _fcmeans_default_tol,
9090
dist_metric::Metric = Euclidean(),
91-
display::Symbol = _fcmeans_default_display
91+
display::Symbol = _fcmeans_default_display,
92+
rng::AbstractRNG = Random.GLOBAL_RNG
9293
) where T<:Real
9394

9495
nrows, ncols = size(data)
9596
2 <= C < ncols || throw(ArgumentError("C must have 2 <= C < n=$ncols ($C given)"))
9697
1 < fuzziness || throw(ArgumentError("fuzziness must be greater than 1 ($fuzziness given)"))
9798

98-
_fuzzy_cmeans(data, C, fuzziness, maxiter, tol, dist_metric, display_level(display))
99+
_fuzzy_cmeans(data, C, fuzziness, maxiter, tol, dist_metric, display_level(display),rng)
99100

100101
end
101102

@@ -108,13 +109,14 @@ function _fuzzy_cmeans(
108109
maxiter::Int, # maximum number of iterations
109110
tol::Real, # tolerance
110111
dist_metric::Metric, # metric to calculate distance
111-
displevel::Int # the level of display
112+
displevel::Int, # the level of display
113+
rng::AbstractRNG # RNG object
112114
) where T<:Real
113115

114116
nrows, ncols = size(data)
115117

116118
# Initialize weights randomly
117-
weights = rand(Float64, ncols, C)
119+
weights = rand(rng, Float64, ncols, C)
118120
weights ./= sum(weights, dims=2)
119121

120122
centers = zeros(T, (nrows, C))

src/kmeans.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ function kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
4949
maxiter::Integer=_kmeans_default_maxiter, # in: maximum number of iterations
5050
tol::Real=_kmeans_default_tol, # in: tolerance of change at convergence
5151
display::Symbol=_kmeans_default_display, # in: level of display
52-
distance::SemiMetric=SqEuclidean()) # in: function to compute distances
52+
distance::SemiMetric=SqEuclidean(), # in: function to compute distances
53+
rng::AbstractRNG=Random.GLOBAL_RNG) # in: RNG object
5354
d, n = size(X)
5455
dc, k = size(centers)
5556
WC = (weights === nothing) ? Int : eltype(weights)
@@ -68,7 +69,7 @@ function kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
6869
mean!(centers, X)
6970
end
7071
return _kmeans!(X, weights, centers, Int(maxiter), Float64(tol),
71-
display_level(display), distance)
72+
display_level(display), distance, rng)
7273
end
7374
end
7475

@@ -99,19 +100,20 @@ function kmeans(X::AbstractMatrix{<:Real}, # in: data matrix (d x
99100
maxiter::Integer=_kmeans_default_maxiter, # in: maximum number of iterations
100101
tol::Real=_kmeans_default_tol, # in: tolerance of change at convergence
101102
display::Symbol=_kmeans_default_display, # in: level of display
102-
distance::SemiMetric=SqEuclidean()) # in: function to calculate distance with
103+
distance::SemiMetric=SqEuclidean(), # in: function to calculate distance with
104+
rng::AbstractRNG=Random.GLOBAL_RNG) # in: RNG object
103105
d, n = size(X)
104106
(1 <= k <= n) || throw(ArgumentError("k must be from 1:n (n=$n), k=$k given."))
105107

106108
# initialize the centers using a type wide enough so that the updates
107109
# centers[i, cj] += X[i, j] * wj will occur without loss of precision through rounding
108110
T = float(weights === nothing ? eltype(X) : promote_type(eltype(X), eltype(weights)))
109-
iseeds = initseeds(init, X, k)
111+
iseeds = initseeds(init, X, k, rng=rng)
110112
centers = copyseeds!(Matrix{T}(undef, d, k), X, iseeds)
111113

112114
kmeans!(X, centers;
113115
weights=weights, maxiter=Int(maxiter), tol=Float64(tol),
114-
display=display, distance=distance)
116+
display=display, distance=distance, rng=rng)
115117
end
116118

117119
#### Core implementation
@@ -123,7 +125,8 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
123125
maxiter::Int, # in: maximum number of iterations
124126
tol::Float64, # in: tolerance of change at convergence
125127
displevel::Int, # in: the level of display
126-
distance::SemiMetric) # in: function to calculate distance
128+
distance::SemiMetric, # in: function to calculate distance
129+
rng::AbstractRNG) # in: RNG object
127130
d, n = size(X)
128131
k = size(centers, 2)
129132
to_update = Vector{Bool}(undef, k) # whether a center needs to be updated
@@ -161,7 +164,7 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
161164
update_centers!(X, weights, assignments, to_update, centers, wcounts)
162165

163166
if !isempty(unused)
164-
repick_unused_centers(X, costs, centers, unused, distance)
167+
repick_unused_centers(X, costs, centers, unused, distance, rng)
165168
to_update[unused] .= true
166169
end
167170

@@ -372,14 +375,15 @@ function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix
372375
costs::Vector{<:Real}, # in: the current assignment costs (n)
373376
centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
374377
unused::Vector{Int}, # in: indices of centers to be updated
375-
distance::SemiMetric) # in: function to calculate the distance with
378+
distance::SemiMetric, # in: function to calculate the distance with
379+
rng::AbstractRNG) # in: RNG object
376380
# pick new centers using a scheme like kmeans++
377381
ds = similar(costs)
378382
tcosts = copy(costs)
379383
n = size(X, 2)
380384

381385
for i in unused
382-
j = wsample(1:n, tcosts)
386+
j = wsample(rng, 1:n, tcosts)
383387
tcosts[j] = 0
384388
v = view(X, :, j)
385389
centers[:, i] = v

src/seeding.jl

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
# Let alg be an instance of such an algorithm, then it should
66
# support the following usage:
77
#
8-
# initseeds!(iseeds, alg, X)
9-
# initseeds_by_costs!(iseeds, alg, costs)
8+
# initseeds!(iseeds, alg, X; kwargs...)
9+
# initseeds_by_costs!(iseeds, alg, costs; kwargs...)
1010
#
1111
# Here:
1212
# - iseeds: a vector of resultant indexes of the chosen seeds
1313
# - alg: the seeding algorithm instance
1414
# - X: the data matrix, each column being a data point
1515
# - costs: pre-computed pairwise cost matrix.
16+
# - kwargs: additional kw-arguments, i.e. `rng`
1617
#
1718
# This function returns iseeds
1819
#
@@ -39,8 +40,8 @@ name of the algorithm.
3940
4041
Returns the vector of `k` seed indices.
4142
"""
42-
initseeds(alg::SeedingAlgorithm, X::AbstractMatrix{<:Real}, k::Integer) =
43-
initseeds!(Vector{Int}(undef, k), alg, X)
43+
initseeds(alg::SeedingAlgorithm, X::AbstractMatrix{<:Real}, k::Integer; kwargs...) =
44+
initseeds!(Vector{Int}(undef, k), alg, X; kwargs...)
4445

4546
"""
4647
initseeds_by_costs(alg::Union{SeedingAlgorithm, Symbol},
@@ -54,8 +55,8 @@ between the points as the cost.
5455
5556
Returns the vector of `k` seed indices.
5657
"""
57-
initseeds_by_costs(alg::SeedingAlgorithm, costs::AbstractMatrix{<:Real}, k::Integer) =
58-
initseeds_by_costs!(Vector{Int}(undef, k), alg, costs)
58+
initseeds_by_costs(alg::SeedingAlgorithm, costs::AbstractMatrix{<:Real}, k::Integer; kwargs...) =
59+
initseeds_by_costs!(Vector{Int}(undef, k), alg, costs; kwargs...)
5960

6061
seeding_algorithm(s::Symbol) =
6162
s == :rand ? RandSeedAlg() :
@@ -71,14 +72,14 @@ end
7172
check_seeding_args(X::AbstractMatrix, iseeds::AbstractVector) =
7273
check_seeding_args(size(X, 2), length(iseeds))
7374

74-
initseeds(algname::Symbol, X::AbstractMatrix{<:Real}, k::Integer) =
75-
initseeds(seeding_algorithm(algname), X, k)::Vector{Int}
75+
initseeds(algname::Symbol, X::AbstractMatrix{<:Real}, k::Integer; kwargs...) =
76+
initseeds(seeding_algorithm(algname), X, k; kwargs...)::Vector{Int}
7677

77-
initseeds_by_costs(algname::Symbol, costs::AbstractMatrix{<:Real}, k::Integer) =
78-
initseeds_by_costs(seeding_algorithm(algname), costs, k)
78+
initseeds_by_costs(algname::Symbol, costs::AbstractMatrix{<:Real}, k::Integer; kwargs...) =
79+
initseeds_by_costs(seeding_algorithm(algname), costs, k; kwargs...)
7980

8081
# use specified vector of seeds
81-
function initseeds(iseeds::AbstractVector{<:Integer}, X::AbstractMatrix{<:Real}, k::Integer)
82+
function initseeds(iseeds::AbstractVector{<:Integer}, X::AbstractMatrix{<:Real}, k::Integer; kwargs...)
8283
length(iseeds) == k ||
8384
throw(ArgumentError("The length of seeds vector ($(length(iseeds))) differs from the number of seeds requested ($k)"))
8485
check_seeding_args(X, iseeds)
@@ -90,8 +91,8 @@ function initseeds(iseeds::AbstractVector{<:Integer}, X::AbstractMatrix{<:Real},
9091
# NOTE no duplicate checks are done, should we?
9192
convert(Vector{Int}, iseeds)
9293
end
93-
initseeds_by_costs(iseeds::AbstractVector{<:Integer}, costs::AbstractMatrix{<:Real}, k::Integer) =
94-
initseeds(iseeds, costs, k) # NOTE: passing costs as X, but should be fine since only size(X, 2) is used
94+
initseeds_by_costs(iseeds::AbstractVector{<:Integer}, costs::AbstractMatrix{<:Real}, k::Integer; kwargs...) =
95+
initseeds(iseeds, costs, k; kwargs...) # NOTE: passing costs as X, but should be fine since only size(X, 2) is used
9596

9697
function copyseeds!(S::AbstractMatrix{<:AbstractFloat},
9798
X::AbstractMatrix{<:Real},
@@ -119,9 +120,10 @@ struct RandSeedAlg <: SeedingAlgorithm end
119120
Initialize `iseeds` with the indices of cluster seeds for the `X` data matrix
120121
using the `alg` seeding algorithm.
121122
"""
122-
function initseeds!(iseeds::IntegerVector, alg::RandSeedAlg, X::AbstractMatrix{<:Real})
123+
function initseeds!(iseeds::IntegerVector, alg::RandSeedAlg, X::AbstractMatrix{<:Real};
124+
rng::AbstractRNG=Random.GLOBAL_RNG)
123125
check_seeding_args(X, iseeds)
124-
sample!(1:size(X, 2), iseeds; replace=false)
126+
sample!(rng, 1:size(X, 2), iseeds; replace=false)
125127
end
126128

127129
"""
@@ -135,9 +137,9 @@ Here, `costs[i, j]` is the cost of assigning points ``i`` and ``j``
135137
to the same cluster. One may, for example, use the squared Euclidean distance
136138
between the points as the cost.
137139
"""
138-
function initseeds_by_costs!(iseeds::IntegerVector, alg::RandSeedAlg, X::AbstractMatrix{<:Real})
140+
function initseeds_by_costs!(iseeds::IntegerVector, alg::RandSeedAlg, X::AbstractMatrix{<:Real}; rng::AbstractRNG=Random.GLOBAL_RNG)
139141
check_seeding_args(X, iseeds)
140-
sample!(1:size(X,2), iseeds; replace=false)
142+
sample!(rng, 1:size(X,2), iseeds; replace=false)
141143
end
142144

143145
"""
@@ -157,13 +159,14 @@ struct KmppAlg <: SeedingAlgorithm end
157159

158160
function initseeds!(iseeds::IntegerVector, alg::KmppAlg,
159161
X::AbstractMatrix{<:Real},
160-
metric::PreMetric = SqEuclidean())
162+
metric::PreMetric = SqEuclidean();
163+
rng::AbstractRNG=Random.GLOBAL_RNG)
161164
n = size(X, 2)
162165
k = length(iseeds)
163166
check_seeding_args(n, k)
164167

165168
# randomly pick the first center
166-
p = rand(1:n)
169+
p = rand(rng, 1:n)
167170
iseeds[1] = p
168171

169172
if k > 1
@@ -173,7 +176,7 @@ function initseeds!(iseeds::IntegerVector, alg::KmppAlg,
173176
# pick remaining (with a chance proportional to mincosts)
174177
tmpcosts = zeros(n)
175178
for j = 2:k
176-
p = wsample(1:n, mincosts)
179+
p = wsample(rng, 1:n, mincosts)
177180
iseeds[j] = p
178181

179182
# update mincosts
@@ -188,13 +191,14 @@ function initseeds!(iseeds::IntegerVector, alg::KmppAlg,
188191
end
189192

190193
function initseeds_by_costs!(iseeds::IntegerVector, alg::KmppAlg,
191-
costs::AbstractMatrix{<:Real})
194+
costs::AbstractMatrix{<:Real};
195+
rng::AbstractRNG=Random.GLOBAL_RNG)
192196
n = size(costs, 1)
193197
k = length(iseeds)
194198
check_seeding_args(n, k)
195199

196200
# randomly pick the first center
197-
p = rand(1:n)
201+
p = rand(rng, 1:n)
198202
iseeds[1] = p
199203

200204
if k > 1
@@ -203,7 +207,7 @@ function initseeds_by_costs!(iseeds::IntegerVector, alg::KmppAlg,
203207

204208
# pick remaining (with a chance proportional to mincosts)
205209
for j = 2:k
206-
p = wsample(1:n, mincosts)
210+
p = wsample(rng, 1:n, mincosts)
207211
iseeds[j] = p
208212

209213
# update mincosts
@@ -230,7 +234,7 @@ Choose the ``k`` points with the highest *centrality* as seeds.
230234
struct KmCentralityAlg <: SeedingAlgorithm end
231235

232236
function initseeds_by_costs!(iseeds::IntegerVector, alg::KmCentralityAlg,
233-
costs::AbstractMatrix{<:Real})
237+
costs::AbstractMatrix{<:Real}; kwargs...)
234238

235239
n = size(costs, 1)
236240
k = length(iseeds)
@@ -255,5 +259,5 @@ function initseeds_by_costs!(iseeds::IntegerVector, alg::KmCentralityAlg,
255259
end
256260

257261
initseeds!(iseeds::IntegerVector, alg::KmCentralityAlg, X::AbstractMatrix{<:Real},
258-
metric::PreMetric = SqEuclidean()) =
259-
initseeds_by_costs!(iseeds, alg, pairwise(metric, X, dims=2))
262+
metric::PreMetric = SqEuclidean(); kwargs...) =
263+
initseeds_by_costs!(iseeds, alg, pairwise(metric, X, dims=2); kwargs...)

0 commit comments

Comments
 (0)