Skip to content

Commit 10269b5

Browse files
authored
Merge pull request #3695 from janezd/kmeans-output-centroid-labels
[ENH] k-Means: Output centroid labels
2 parents 9b3ba89 + 20d04c0 commit 10269b5

File tree

2 files changed

+88
-9
lines changed

2 files changed

+88
-9
lines changed

Orange/widgets/unsupervised/owkmeans.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from Orange.clustering.kmeans import KMeansModel, SILHOUETTE_MAX_SAMPLES
1212
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
1313
from Orange.data.util import get_unique_names
14+
from Orange.preprocess.impute import ReplaceUnknowns
1415
from Orange.widgets import widget, gui
1516
from Orange.widgets.settings import Setting
1617
from Orange.widgets.utils.annotated_data import \
@@ -33,7 +34,8 @@ def __init__(self, parent=None):
3334
def rowCount(self, index=QModelIndex()):
3435
return 0 if index.isValid() else len(self.scores)
3536

36-
def columnCount(self, index=QModelIndex()):
37+
@staticmethod
38+
def columnCount(_index=QModelIndex()):
3739
return 1
3840

3941
def flags(self, index):
@@ -64,10 +66,12 @@ def data(self, index, role=Qt.DisplayRole):
6466
return score
6567
elif role == gui.BarRatioRole and valid:
6668
return score
69+
return None
6770

68-
def headerData(self, row, orientation, role=Qt.DisplayRole):
71+
def headerData(self, row, _orientation, role=Qt.DisplayRole):
6972
if role == Qt.DisplayRole:
7073
return str(row + self.start_k)
74+
return None
7175

7276

7377
class Task:
@@ -443,8 +447,9 @@ def update_results(self):
443447

444448
def selected_row(self):
445449
indices = self.table_view.selectedIndexes()
446-
if indices:
447-
return indices[0].row()
450+
if not indices:
451+
return None
452+
return indices[0].row()
448453

449454
def select_row(self):
450455
self.send_data()
@@ -468,21 +473,49 @@ def send_data(self):
468473
values=["C%d" % (x + 1) for x in range(km.k)]
469474
)
470475
clust_ids = km(self.data)
476+
clust_col = clust_ids.X.ravel()
471477
silhouette_var = ContinuousVariable(
472478
get_unique_names(domain, "Silhouette"))
473479
if km.silhouette_samples is not None:
474480
self.Warning.no_silhouettes.clear()
475481
scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
482+
clust_scores = []
483+
for i in range(km.k):
484+
in_clust = clust_col == i
485+
if in_clust.any():
486+
clust_scores.append(np.mean(scores[in_clust]))
487+
else:
488+
clust_scores.append(0.)
489+
clust_scores = np.atleast_2d(clust_scores).T
476490
else:
477491
self.Warning.no_silhouettes()
478492
scores = np.nan
493+
clust_scores = np.full((km.k, 1), np.nan)
479494

480495
new_domain = add_columns(domain, metas=[cluster_var, silhouette_var])
481496
new_table = self.data.transform(new_domain)
482-
new_table.get_column_view(cluster_var)[0][:] = clust_ids.X.ravel()
497+
new_table.get_column_view(cluster_var)[0][:] = clust_col
483498
new_table.get_column_view(silhouette_var)[0][:] = scores
484499

485-
centroids = Table(Domain(km.pre_domain.attributes), km.centroids)
500+
centroid_attributes = [
501+
attr.compute_value.variable
502+
if isinstance(attr.compute_value, ReplaceUnknowns)
503+
and attr.compute_value.variable in domain.attributes
504+
else attr
505+
for attr in km.pre_domain.attributes]
506+
centroid_domain = add_columns(
507+
Domain(centroid_attributes, [], domain.metas),
508+
metas=[cluster_var, silhouette_var])
509+
centroids = Table(
510+
centroid_domain, km.centroids, None,
511+
np.hstack((np.full((km.k, len(domain.metas)), np.nan),
512+
np.arange(km.k).reshape(km.k, 1),
513+
clust_scores))
514+
)
515+
if self.data.name == Table.name:
516+
centroids.name = "centroids"
517+
else:
518+
centroids.name = f"{self.data.name} centroids"
486519

487520
self.Outputs.annotated_data.send(new_table)
488521
self.Outputs.centroids.send(centroids)

Orange/widgets/unsupervised/tests/test_owkmeans.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,60 @@ def test_data_on_output(self):
197197
# removing data should have cleared the output
198198
self.assertEqual(self.widget.data, None)
199199

200+
@patch("Orange.clustering.kmeans.KMeansModel.__call__")
201+
def test_centroids_on_output(self, km_call):
202+
ret = km_call.return_value = Mock()
203+
ret.X = np.array([0] * 50 + [1] * 100)
204+
ret.silhouette_samples = np.arange(150) / 150
205+
206+
widget = self.widget
207+
widget.optimize_k = False
208+
widget.k = 4
209+
self.send_signal(widget.Inputs.data, self.iris)
210+
self.commit_and_wait()
211+
212+
widget.clusterings[4].silhouette_samples = np.arange(150) / 150
213+
widget.send_data()
214+
out = self.get_output(widget.Outputs.centroids)
215+
np.testing.assert_almost_equal(
216+
out.metas,
217+
[[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5],
218+
[1, np.mean(np.arctan(np.arange(50, 150) / 150)) / np.pi + 0.5],
219+
[2, 0], [3, 0]])
220+
self.assertEqual(out.name, "iris centroids")
221+
222+
def test_centroids_domain_on_output(self):
223+
widget = self.widget
224+
widget.optimize_k = False
225+
widget.k = 4
226+
heart_disease = Table("heart_disease")
227+
heart_disease.name = Table.name # untitled
228+
self.send_signal(widget.Inputs.data, heart_disease)
229+
self.commit_and_wait()
230+
231+
in_attrs = heart_disease.domain.attributes
232+
out = self.get_output(widget.Outputs.centroids)
233+
out_attrs = out.domain.attributes
234+
out_ids = {id(attr) for attr in out_attrs}
235+
for attr in in_attrs:
236+
self.assertEqual(
237+
id(attr) in out_ids, attr.is_continuous,
238+
f"at attribute '{attr.name}'"
239+
)
240+
self.assertEqual(
241+
len(out_attrs),
242+
sum(attr.is_continuous or len(attr.values) for attr in in_attrs))
243+
self.assertEqual(out.name, "centroids")
244+
200245
class KMeansFail(Orange.clustering.KMeans):
201246
fail_on = set()
202247

203-
def fit(self, *args):
248+
def fit(self, X, Y=None):
204249
# when not optimizing, params is empty?!
205250
k = self.params.get("n_clusters", 3)
206251
if k in self.fail_on:
207252
raise ValueError("k={} fails".format(k))
208-
return super().fit(*args)
253+
return super().fit(X, Y)
209254

210255
@patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail)
211256
def test_optimization_fails(self):
@@ -346,7 +391,8 @@ def test_silhouette_column(self):
346391
widget.k = 4
347392
widget.optimize_k = False
348393

349-
random = np.random.RandomState(0) # Avoid randomness in the test
394+
# Avoid randomness in the test
395+
random = np.random.RandomState(0) # pylint: disable=no-member
350396
table = Table(random.rand(110, 2))
351397
with patch("Orange.clustering.kmeans.SILHOUETTE_MAX_SAMPLES", 100):
352398
self.send_signal(self.widget.Inputs.data, table)

0 commit comments

Comments
 (0)