Skip to content

Commit fd9e691

Browse files
committed
add code and tests
0 parents  commit fd9e691

File tree

6 files changed

+316
-0
lines changed

6 files changed

+316
-0
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Manifest.toml
2+
.ipynb_checkpoints
3+
*~
4+
#*
5+
*.bu
6+
.DS_Store
7+
sandbox/

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2020 The Alan Turing Institute
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

Project.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name = "MLJClusteringInterface"
2+
uuid = "d354fa79-ed1c-40d4-88ef-b8c7bd1568af"
3+
authors = ["Anthony D. Blaom <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
8+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
9+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
10+
11+
[compat]
12+
Distances = "0.9, 0.10"
13+
MLJModelInterface = "0.3.6"
14+
Clustering = "0.14"
15+
julia = "1"
16+
17+
[extras]
18+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
20+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22+
23+
[targets]
24+
test = ["LinearAlgebra", "MLJBase", "Random", "Test"]

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# MLJ <> Clustering.jl
2+
Repository implementing the MLJ interface for
3+
[Clustering.jl](https://github.com/JuliaStats/Clustering.jl) models.
4+
5+
6+
[![Build Status](https://travis-ci.com/alan-turing-institute/MLJClusteringInterface.jl.svg?branch=master)](https://travis-ci.com/github/alan-turing-institute/MLJClusteringInterface.jl)
7+
[![Coverage](http://codecov.io/github/alan-turing-institute/MLJClusteringInterface.jl/coverage.svg?branch=master)](https://codecov.io/gh/alan-turing-institute/MLJClusteringInterface.jl)
8+
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
9+

src/MLJClusteringInterface.jl

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# NOTE: there's a `kmeans!` function that updates centers, maybe a candidate
2+
# for the `update` machinery. Same for `kmedoids!`
3+
# NOTE: if the prediction is done on the original array, just the assignment
4+
# should be returned, unclear what's the best way of doing this.
5+
6+
module MLJClusteringInterface
7+
8+
# ===================================================================
9+
# IMPORTS
10+
import Clustering
11+
import MLJModelInterface
12+
import MLJModelInterface: Continuous, Count, Finite, MulticlassTable, OrderedFactor,
13+
@mlj_model, metadata_model, metadata_pkg
14+
15+
using Distances
16+
17+
# ===================================================================
18+
## EXPORTS
19+
export KMeans, KMedoids
20+
21+
# ===================================================================
22+
## CONSTANTS
23+
# Define constants for easy referencing of packages
24+
# Short aliases for the wrapped packages.
const MMI = MLJModelInterface
const Cl = Clustering

# Text fragments interpolated into the model doc-strings below and reused as
# the `descr` entry of each model's metadata.
const KMeansDescription = """
K-Means algorithm: find K centroids corresponding to K clusters in the data.
"""

const KMedoidsDescription = """
K-Medoids algorithm: find K centroids corresponding to K clusters in the data.
Unlike K-Means, the centroids are found among data points themselves.
"""

# Keyword documentation shared by the `KMeans` and `KMedoids` doc-strings.
const KMFields = """
## Keywords

* `k=3` : number of centroids
* `metric` : distance metric to use
"""
43+
44+
####
#### KMeans
####

# NOTE: the doc-string has to sit directly above the definition. With a blank
# line in between, Julia parses it as a stand-alone string expression and never
# attaches it to `KMeans`.
"""
KMeans(; kwargs...)

$KMeansDescription

$KMFields

See also the
[package documentation](http://juliastats.github.io/Clustering.jl/latest/kmeans.html).
"""
@mlj_model mutable struct KMeans <: MMI.Unsupervised
    # number of clusters requested from `Clustering.kmeans`; the `_ ≥ 2` clause
    # is the constraint expression handed to `@mlj_model`
    k::Int = 3::(_ ≥ 2)
    # distance passed to `Clustering.kmeans` and reused by `transform`/`predict`
    metric::SemiMetric = SqEuclidean()
end
62+
63+
####
64+
#### KMeans
65+
####
66+
67+
function MMI.fit(model::KMeans, verbosity::Int, X)
    # `transpose` yields a lazy LinearAlgebra.Transpose, which `kmeans` can
    # handle, so the data is not copied here.
    features = transpose(MMI.matrix(X))
    km = Cl.kmeans(features, model.k; distance=model.metric)
    labels = MMI.categorical(1:model.k)

    # `km.centers` is p x k; `km.assignments` has one entry per observation.
    report = (assignments=km.assignments, cluster_labels=labels)

    # Returned as (fitresult, cache, report); nothing is cached.
    return (km.centers, labels), nothing, report
end
81+
82+
# Expose the learned centers (first entry of the fitresult) under the key `centers`.
function MMI.fitted_params(::KMeans, fitresult)
    centers = first(fitresult)
    return (centers=centers,)
end
83+
84+
# Map each row of `X` to its distances from the fitted centers and return the
# result as a table built with `X` as prototype.
function MMI.transform(model::KMeans, fitresult, X)
    centers = fitresult[1]
    # pairwise distance from samples to centers; both operands hold one point
    # per column, hence `dims=2`
    distances = pairwise(
        model.metric,
        transpose(MMI.matrix(X)),
        centers,
        dims=2
    )
    return MMI.table(distances, prototype=X)
end
94+
95+
"""
96+
KMedoids(; kwargs...)
97+
98+
$KMedoidsDescription
99+
100+
$KMFields
101+
102+
See also the
103+
[package documentation](http://juliastats.github.io/Clustering.jl/latest/kmedoids.html).
104+
"""
105+
@mlj_model mutable struct KMedoids <: MMI.Unsupervised
    # number of medoids requested from `Clustering.kmedoids`; the `_ ≥ 2`
    # clause is the constraint expression handed to `@mlj_model`
    k::Int = 3::(_ ≥ 2)
    # distance used to build the cost matrix in `fit` and reused by
    # `transform`/`predict`
    metric::SemiMetric = SqEuclidean()
end
109+
110+
function MMI.fit(model::KMedoids, verbosity::Int, X)
    # NOTE: using transpose=true will materialize the transpose (~ permutedims),
    # KMedoids does not yet accept LinearAlgebra.Transpose
    Xarray = MMI.matrix(X, transpose=true)
    # cost matrix: all the pairwise distances between observations (n x n)
    cost_array = pairwise(model.metric, Xarray, dims=2)
    result = Cl.kmedoids(cost_array, model.k)
    cluster_labels = MMI.categorical(1:model.k)
    # Copy the selected medoid columns rather than taking a `view`: a view keeps
    # a reference to its parent, so the fitresult would otherwise hold on to the
    # entire training matrix.
    medoids = Xarray[:, result.medoids]
    fitresult = (medoids, cluster_labels)
    cache = nothing
    report = (
        assignments=result.assignments, # size n
        cluster_labels=cluster_labels
    )
    return fitresult, cache, report
end
126+
127+
# Expose the selected medoids (first entry of the fitresult) under the key `medoids`.
function MMI.fitted_params(::KMedoids, fitresult)
    medoids = first(fitresult)
    return (medoids=medoids,)
end
128+
129+
# Map each row of `X` to its distances from the fitted medoids and return the
# result as a table built with `X` as prototype.
function MMI.transform(model::KMedoids, fitresult, X)
    medoids = fitresult[1]
    # pairwise distance from samples to medoids; both operands hold one point
    # per column, hence `dims=2`
    distances = pairwise(
        model.metric,
        MMI.matrix(X, transpose=true),
        medoids,
        dims=2
    )
    return MMI.table(distances, prototype=X)
end
138+
139+
####
140+
#### Predict methods
141+
####
142+
143+
# Assign every row of `Xnew` to the nearest fitted location (center or medoid)
# under `model.metric` and return the matching categorical cluster labels.
# Ties go to the lowest cluster index because the comparison is strict.
function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew)
    locations, cluster_labels = fitresult
    Xarray = MMI.matrix(Xnew)
    n, p = size(Xarray)

    # Validate once up front instead of relying on checks inside the loop.
    if p != size(locations, 1)
        throw(DimensionMismatch(
            "Xnew has $p features but the model was fitted on $(size(locations, 1))"
        ))
    end

    pred = zeros(Int, n)
    for i in 1:n
        sample = view(Xarray, i, :)
        minv = Inf
        # Iterate over the fitted locations themselves rather than `model.k`:
        # the model is mutable, so `k` may have been changed after fitting.
        for j in axes(locations, 2)
            curv = evaluate(model.metric, sample, view(locations, :, j))
            if curv < minv
                minv = curv
                pred[i] = j
            end
        end
    end
    return cluster_labels[pred]
end
162+
163+
####
164+
#### METADATA
165+
####
166+
167+
# Package-level metadata: the same values are registered for both models.
for model_type in (KMeans, KMedoids)
    metadata_pkg(
        model_type;
        name="Clustering",
        uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
        url="https://github.com/JuliaStats/Clustering.jl",
        julia=true,
        license="MIT",
        is_wrapper=false,
    )
end

# Model-level metadata: both models take and return tables of `Continuous`
# columns and do not support weights; only the description differs.
for (model_type, description) in (
    (KMeans, KMeansDescription),
    (KMedoids, KMedoidsDescription),
)
    metadata_model(
        model_type;
        input=MMI.Table(Continuous),
        output=MMI.Table(Continuous),
        weights=false,
        descr=description,
    )
end
192+
193+
end # module
194+

test/runtests.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import Clustering
2+
import Distances
3+
import LinearAlgebra: norm
4+
5+
using MLJBase
6+
using MLJClusteringInterface
7+
using Random:seed!
8+
using Test
9+
10+
const Dist = Distances
11+
12+
seed!(132442)
13+
X, y = @load_crabs
14+
15+
####
16+
#### KMEANS
17+
####
18+
19+
@testset "Kmeans" begin
    barekm = KMeans()
    fitresult, cache, report = fit(barekm, 1, X)
    R = matrix(transform(barekm, fitresult, X))
    X_array = matrix(X)
    # `transform` must return the squared Euclidean distance (the default
    # metric) from a sample to a center, e.g. first point to second center
    @test R[1, 2] ≈ norm(view(X_array, 1, :) .- view(fitresult[1], :, 2))^2
    @test R[10, 3] ≈ norm(view(X_array, 10, :) .- view(fitresult[1], :, 3))^2

    # `predict` must pick the center with the smallest distance
    p = predict(barekm, fitresult, X)
    @test argmin(R[1, :]) == p[1]
    @test argmin(R[10, :]) == p[10]

    # metadata registered in the package
    infos = info_dict(barekm)
    @test infos[:package_name] == "Clustering"
    @test infos[:is_pure_julia]
    @test infos[:package_license] == "MIT"
    @test infos[:input_scitype] == Table(Continuous)
    @test infos[:output_scitype] == Table(Continuous)
    # no assertion on the content; this only exercises the lookup
    infos[:docstring]
end
40+
41+
####
42+
#### KMEDOIDS
43+
####
44+
45+
@testset "Kmedoids" begin
    barekm = KMedoids()
    fitresult, cache, report = fit(barekm, 1, X)
    X_array = matrix(X)
    R = matrix(transform(barekm, fitresult, X))
    # `transform` must return the model metric evaluated between a sample and
    # a medoid
    @test R[1, 2] ≈ Dist.evaluate(
        barekm.metric, view(X_array, 1, :), view(fitresult[1], :, 2)
    )
    @test R[10, 3] ≈ Dist.evaluate(
        barekm.metric, view(X_array, 10, :), view(fitresult[1], :, 3)
    )
    # predicting on the training data must reproduce the reported assignments
    p = predict(barekm, fitresult, X)
    @test all(report.assignments .== p)
    infos = info_dict(barekm)
    @test infos[:input_scitype] == Table(Continuous)
    @test infos[:output_scitype] == Table(Continuous)
end

0 commit comments

Comments
 (0)