Skip to content

Commit a283eb1

Browse files
authored
Merge pull request #4082 from VesnaT/kmeans_selection
[FIX] K-means: Save Silhouette Scores selection
2 parents e1f5e9f + a193f02 commit a283eb1

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

Orange/widgets/unsupervised/owkmeans.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class Warning(widget.OWWidget.Warning):
142142
max_iterations = Setting(300)
143143
n_init = Setting(10)
144144
smart_init = Setting(0) # KMeans++
145+
selection = Setting(None, schema_only=True) # type: Optional[int]
145146
auto_commit = Setting(True)
146147

147148
settings_version = 2
@@ -158,6 +159,7 @@ def __init__(self):
158159
super().__init__()
159160

160161
self.data = None # type: Optional[Table]
162+
self.__pending_selection = self.selection # type: Optional[int]
161163
self.clusterings = {}
162164

163165
self.__executor = ThreadExecutor(parent=self)
@@ -443,17 +445,25 @@ def update_results(self):
443445
key=lambda x: 0 if isinstance(scores[x], str) else scores[x]
444446
)
445447
self.table_model.set_scores(scores, self.k_from)
446-
self.table_view.selectRow(best_row)
448+
self.apply_selection(best_row)
447449
self.table_view.setFocus(Qt.OtherFocusReason)
448450
self.table_view.resizeRowsToContents()
449451

452+
def apply_selection(self, best_row):
453+
pending = best_row
454+
if self.__pending_selection is not None:
455+
pending = self.__pending_selection
456+
self.__pending_selection = None
457+
self.table_view.selectRow(pending)
458+
450459
def selected_row(self):
451460
indices = self.table_view.selectedIndexes()
452461
if not indices:
453462
return None
454463
return indices[0].row()
455464

456465
def select_row(self):
466+
self.selection = self.selected_row()
457467
self.send_data()
458468

459469
def preproces(self, data):
@@ -535,6 +545,7 @@ def send_data(self):
535545
@check_sql_input
536546
def set_data(self, data):
537547
self.data, old_data = data, self.data
548+
self.selection = None
538549

539550
# Do not needlessly recluster the data if X hasn't changed
540551
if old_data and self.data and array_equal(self.data.X, old_data.X):

Orange/widgets/unsupervised/tests/test_owkmeans.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def setUp(self):
4444
self.widget = self.create_widget(
4545
OWKMeans, stored_settings={"auto_commit": False, "version": 2}
4646
) # type: OWKMeans
47-
self.iris = Table("iris")
48-
self.iris.X[0, 0] = np.nan
47+
self.data = Table("heart_disease")
4948

5049
def tearDown(self):
5150
self.widget.onDeleteWidget()
@@ -61,7 +60,7 @@ def test_migrate_version_1_settings(self):
6160
def test_optimization_report_display(self):
6261
"""Check visibility of the table after selecting number of clusters"""
6362
self.widget.auto_commit = True
64-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
63+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
6564
self.widget.optimize_k = True
6665
radio_buttons = self.widget.controls.optimize_k.findChildren(QRadioButton)
6766

@@ -80,7 +79,7 @@ def test_optimization_report_display(self):
8079
def test_changing_k_changes_radio(self):
8180
widget = self.widget
8281
widget.auto_commit = True
83-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
82+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
8483

8584
widget.optimize_k = True
8685

@@ -118,7 +117,7 @@ def test_no_data_hides_main_area(self):
118117

119118
self.send_signal(self.widget.Inputs.data, None, wait=5000)
120119
self.assertTrue(self.widget.mainArea.isHidden())
121-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
120+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
122121
self.assertFalse(self.widget.mainArea.isHidden())
123122
self.send_signal(self.widget.Inputs.data, None, wait=5000)
124123
self.assertTrue(self.widget.mainArea.isHidden())
@@ -129,7 +128,7 @@ def test_data_limits(self):
129128
widget = self.widget
130129
widget.auto_commit = False
131130

132-
self.send_signal(self.widget.Inputs.data, self.iris[:5])
131+
self.send_signal(self.widget.Inputs.data, self.data[:5])
133132

134133
widget.k = 10
135134
self.commit_and_wait()
@@ -159,7 +158,7 @@ def test_use_cache(self):
159158
"""Cache various clusterings for the dataset until data changes."""
160159
widget = self.widget
161160
widget.auto_commit = False
162-
self.send_signal(self.widget.Inputs.data, self.iris)
161+
self.send_signal(self.widget.Inputs.data, self.data)
163162

164163
with patch.object(widget, "_compute_clustering",
165164
wraps=widget._compute_clustering) as compute:
@@ -191,7 +190,7 @@ def test_use_cache(self):
191190
def test_data_on_output(self):
192191
"""Check if data is on output after create widget and run"""
193192
self.widget.auto_commit = True
194-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
193+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
195194
self.widget.apply_button.button.click()
196195
self.assertNotEqual(self.widget.data, None)
197196
# Disconnect the data
@@ -203,19 +202,19 @@ def test_centroids_on_output(self):
203202
widget = self.widget
204203
widget.optimize_k = False
205204
widget.k = 4
206-
self.send_signal(widget.Inputs.data, self.iris)
205+
self.send_signal(widget.Inputs.data, self.data)
207206
self.commit_and_wait()
208-
widget.clusterings[widget.k].labels = np.array([0] * 50 + [1] * 100).flatten()
207+
widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten()
209208

210209
widget.samples_scores = lambda x: np.arctan(
211-
np.arange(150) / 150) / np.pi + 0.5
210+
np.arange(303) / 303) / np.pi + 0.5
212211
widget.send_data()
213212
out = self.get_output(widget.Outputs.centroids)
214213
np.testing.assert_array_almost_equal(
215-
np.array([[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5],
216-
[1, np.mean(np.arctan(np.arange(50, 150) / 150)) / np.pi + 0.5],
214+
np.array([[0, np.mean(np.arctan(np.arange(100) / 303)) / np.pi + 0.5],
215+
[1, np.mean(np.arctan(np.arange(100, 303) / 303)) / np.pi + 0.5],
217216
[2, 0], [3, 0]]), out.metas.astype(float))
218-
self.assertEqual(out.name, "iris centroids")
217+
self.assertEqual(out.name, "heart_disease centroids")
219218

220219
def test_centroids_domain_on_output(self):
221220
widget = self.widget
@@ -262,13 +261,13 @@ def test_optimization_fails(self):
262261

263262
with patch.object(
264263
model, "set_scores", wraps=model.set_scores) as set_scores:
265-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
264+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
266265
scores, start_k = set_scores.call_args[0]
267-
X = self.widget.preproces(self.iris).X
266+
X = self.widget.preproces(self.data).X
268267
self.assertEqual(
269268
scores,
270269
[km if isinstance(km, str) else silhouette_score(
271-
X, km(self.iris))
270+
X, km(self.data))
272271
for km in (widget.clusterings[k] for k in range(3, 9))]
273272
)
274273
self.assertEqual(start_k, 3)
@@ -302,7 +301,7 @@ def test_run_fails(self):
302301
self.widget.auto_commit = True
303302
self.widget.optimize_k = False
304303
self.KMeansFail.fail_on = {3}
305-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
304+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
306305
self.assertTrue(self.widget.Error.failed.is_shown())
307306
self.assertIsNone(self.get_output(self.widget.Outputs.annotated_data))
308307

@@ -362,7 +361,7 @@ def test_not_enough_rows(self):
362361
Widget should not crash when there is less rows than k_from.
363362
GH-2172
364363
"""
365-
table = self.iris[0:1, :]
364+
table = self.data[0:1, :]
366365
self.widget.controls.k_from.setValue(2)
367366
self.widget.controls.k_to.setValue(9)
368367
self.send_signal(self.widget.Inputs.data, table)
@@ -374,7 +373,7 @@ def test_from_to_table(self):
374373
"""
375374
k_from, k_to = 2, 9
376375
self.widget.controls.k_from.setValue(k_from)
377-
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
376+
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
378377
check = lambda x: 2 if x - k_from + 1 < 2 else x - k_from + 1
379378
for i in range(k_from, k_to):
380379
self.widget.controls.k_to.setValue(i)
@@ -415,7 +414,7 @@ def test_invalidate_clusterings_cancels_jobs(self):
415414
widget.auto_commit = False
416415

417416
# Send the data without waiting
418-
self.send_signal(widget.Inputs.data, self.iris)
417+
self.send_signal(widget.Inputs.data, self.data)
419418
widget.unconditional_commit()
420419
# Now, invalidate by changing max_iter
421420
widget.max_iterations = widget.max_iterations + 1
@@ -460,7 +459,7 @@ def test_do_not_recluster_on_same_data(self):
460459

461460
def test_correct_smart_init(self):
462461
# due to a bug where wrong init was passed to _compute_clustering
463-
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
462+
self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000)
464463
self.widget.smart_init = 0
465464
self.widget.clusterings = {}
466465
with patch.object(self.widget, "_compute_clustering",
@@ -476,7 +475,7 @@ def test_correct_smart_init(self):
476475

477476
def test_always_same_cluster(self):
478477
"""The same random state should always return the same clusters"""
479-
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
478+
self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000)
480479

481480
def cluster():
482481
self.widget.invalidate() # reset caches
@@ -500,6 +499,23 @@ def test_error_no_attributes(self):
500499
self.send_signal(self.widget.Inputs.data, table)
501500
self.assertTrue(self.widget.Error.no_attributes.is_shown())
502501

502+
def test_saved_selection(self):
503+
self.widget.send_data = Mock()
504+
self.widget.optimize_k = True
505+
self.send_signal(self.widget.Inputs.data, self.data)
506+
self.wait_until_stop_blocking()
507+
self.widget.table_view.selectRow(2)
508+
self.assertEqual(self.widget.selected_row(), 2)
509+
self.assertEqual(self.widget.send_data.call_count, 3)
510+
settings = self.widget.settingsHandler.pack_data(self.widget)
511+
512+
w = self.create_widget(OWKMeans, stored_settings=settings)
513+
w.send_data = Mock()
514+
self.send_signal(w.Inputs.data, self.data, widget=w)
515+
self.wait_until_stop_blocking(widget=w)
516+
self.assertEqual(w.send_data.call_count, 2)
517+
self.assertEqual(self.widget.selected_row(), w.selected_row())
518+
503519

504520
if __name__ == "__main__":
505521
unittest.main()

0 commit comments

Comments
 (0)