Skip to content

Commit 03b66c9

Browse files
authored
Merge pull request #3090 from biolab/fix-owkmeans-init
[FIX] owkmeans: fix initialization choice
2 parents 7c80da3 + aff45be commit 03b66c9

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

Orange/widgets/unsupervised/owkmeans.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,8 @@ class Warning(widget.OWWidget.Warning):
115115
"Too few ({}) unique data instances for {} clusters"
116116
)
117117

118-
INIT_METHODS = "Initialize with KMeans++", "Random initialization"
118+
INIT_METHODS = (("Initialize with KMeans++", "k-means++"),
119+
("Random initialization", "random"))
119120

120121
resizing_enabled = False
121122
buttons_area_orientation = Qt.Vertical
@@ -181,7 +182,7 @@ def __init__(self):
181182

182183
box = gui.vBox(self.controlArea, "Initialization")
183184
gui.comboBox(
184-
box, self, "smart_init", items=self.INIT_METHODS,
185+
box, self, "smart_init", items=[m[0] for m in self.INIT_METHODS],
185186
callback=self.invalidate)
186187

187188
layout = QGridLayout()
@@ -316,7 +317,7 @@ def __launch_tasks(self, ks):
316317
self._compute_clustering,
317318
data=self.data,
318319
k=k,
319-
init=['random', 'k-means++'][self.smart_init],
320+
init=self.INIT_METHODS[self.smart_init][1],
320321
n_init=self.n_init,
321322
max_iter=self.max_iterations,
322323
silhouette=True,
@@ -485,7 +486,7 @@ def send_report(self):
485486
k_clusters = self.k_from + self.selected_row()
486487
else:
487488
k_clusters = self.k
488-
init_method = self.INIT_METHODS[self.smart_init]
489+
init_method = self.INIT_METHODS[self.smart_init][0]
489490
init_method = init_method[0].lower() + init_method[1:]
490491
self.report_items((
491492
("Number of clusters", k_clusters),

Orange/widgets/unsupervised/tests/test_owkmeans.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -411,6 +411,21 @@ def test_do_not_recluster_on_same_data(self):
411411
self.commit_and_wait()
412412
self.assertEqual(call_count + 1, commit.call_count)
413413

414+
def test_correct_smart_init(self):
415+
# due to a bug where wrong init was passed to _compute_clustering
416+
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
417+
self.widget.smart_init = 0
418+
with patch.object(self.widget, "_compute_clustering",
419+
wraps=self.widget._compute_clustering) as compute:
420+
self.commit_and_wait()
421+
self.assertEqual(compute.call_args[1]['init'], "k-means++")
422+
self.widget.invalidate() # reset caches
423+
self.widget.smart_init = 1
424+
with patch.object(self.widget, "_compute_clustering",
425+
wraps=self.widget._compute_clustering) as compute:
426+
self.commit_and_wait()
427+
self.assertEqual(compute.call_args[1]['init'], "random")
428+
414429

415430
if __name__ == "__main__":
416431
unittest.main()

0 commit comments

Comments
 (0)