Skip to content

Commit b9ff8db

Browse files
committed
Merge pull request #1974 from janezd/fix-kmeans-fail
[FIX] KMeans: Fix crashes when underlying algorithm fails (cherry picked from commit c80741b)
1 parent c95cc7e commit b9ff8db

File tree

2 files changed

+136
-44
lines changed

2 files changed

+136
-44
lines changed

Orange/widgets/unsupervised/owkmeans.py

Lines changed: 75 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
from AnyQt.QtWidgets import QGridLayout, QSizePolicy, QTableView
5-
from AnyQt.QtGui import QStandardItemModel, QStandardItem, QIntValidator
5+
from AnyQt.QtGui import QStandardItemModel, QStandardItem, QIntValidator, QBrush
66
from AnyQt.QtCore import Qt, QTimer
77

88
from Orange.clustering import KMeans
@@ -24,15 +24,18 @@ class OWKMeans(widget.OWWidget):
2424
outputs = [("Annotated Data", Table, widget.Default),
2525
("Centroids", Table)]
2626

27+
class Error(widget.OWWidget.Error):
28+
failed = widget.Msg("Clustering failed\nError: {}")
29+
2730
INIT_KMEANS, INIT_RANDOM = range(2)
2831
INIT_METHODS = "Initialize with KMeans++", "Random initialization"
2932

3033
SILHOUETTE, INTERCLUSTER, DISTANCES = range(3)
31-
SCORING_METHODS = [("Silhouette", lambda km: km.silhouette, False),
34+
SCORING_METHODS = [("Silhouette", lambda km: km.silhouette, False, True),
3235
("Inter-cluster distance",
33-
lambda km: km.inter_cluster, True),
36+
lambda km: km.inter_cluster, True, False),
3437
("Distance to centroids",
35-
lambda km: km.inertia, True)]
38+
lambda km: km.inertia, True, False)]
3639

3740
OUTPUT_CLASS, OUTPUT_ATTRIBUTE, OUTPUT_META = range(3)
3841
OUTPUT_METHODS = ("Class", "Feature", "Meta")
@@ -150,7 +153,8 @@ def __init__(self):
150153
table.setSelectionMode(QTableView.SingleSelection)
151154
table.setSelectionBehavior(QTableView.SelectRows)
152155
table.verticalHeader().hide()
153-
table.setItemDelegateForColumn(1, gui.TableBarItem(self))
156+
self.bar_delegate = gui.ColoredBarItemDelegate(self, color=Qt.cyan)
157+
table.setItemDelegateForColumn(1, self.bar_delegate)
154158
table.setModel(self.table_model)
155159
table.selectionModel().selectionChanged.connect(
156160
self.table_item_selected)
@@ -219,6 +223,7 @@ def run_optimization(self):
219223
try:
220224
self.controlArea.setDisabled(True)
221225
self.optimization_runs = []
226+
error = ""
222227
if not self.check_data_size(self.k_from, self.Error):
223228
return
224229
self.check_data_size(self.k_to, self.Warning)
@@ -231,7 +236,15 @@ def run_optimization(self):
231236
for k in range(self.k_from, k_to + 1):
232237
progress.advance()
233238
kmeans.params["n_clusters"] = k
234-
self.optimization_runs.append((k, kmeans(self.data)))
239+
try:
240+
self.optimization_runs.append((k, kmeans(self.data)))
241+
except BaseException as exc:
242+
error = str(exc)
243+
self.optimization_runs.append((k, error))
244+
if all(isinstance(score, str)
245+
for _, score in self.optimization_runs):
246+
self.Error.failed(error) # Report just the last error
247+
self.optimization_runs = []
235248
finally:
236249
self.controlArea.setDisabled(False)
237250
self.show_results()
@@ -240,11 +253,15 @@ def run_optimization(self):
240253
def cluster(self):
241254
if not self.check_data_size(self.k, self.Error):
242255
return
243-
self.km = KMeans(
244-
n_clusters=self.k,
245-
init=['random', 'k-means++'][self.smart_init],
246-
n_init=self.n_init,
247-
max_iter=self.max_iterations)(self.data)
256+
try:
257+
self.km = KMeans(
258+
n_clusters=self.k,
259+
init=['random', 'k-means++'][self.smart_init],
260+
n_init=self.n_init,
261+
max_iter=self.max_iterations)(self.data)
262+
except BaseException as exc:
263+
self.Error.failed(str(exc))
264+
self.km = None
248265
self.send_data()
249266

250267
def run(self):
@@ -260,40 +277,55 @@ def commit(self):
260277
self.run()
261278

262279
def show_results(self):
263-
minimize = self.SCORING_METHODS[self.scoring][2]
264-
k_scores = [(k, self.SCORING_METHODS[self.scoring][1](run)) for
265-
k, run in self.optimization_runs]
266-
scores = list(zip(*k_scores))[1]
267-
if minimize:
268-
best_score, worst_score = min(scores), max(scores)
280+
_, scoring_method, minimize, normal = self.SCORING_METHODS[self.scoring]
281+
k_scores = [(k,
282+
scoring_method(run) if not isinstance(run, str) else run)
283+
for k, run in self.optimization_runs]
284+
scores = [score for _, score in k_scores if not isinstance(score, str)]
285+
286+
min_score, max_score = min(scores, default=0), max(scores, default=1)
287+
best_score = min_score if minimize else max_score
288+
if normal:
289+
min_score, max_score = 0, 1
290+
nplaces = 3
269291
else:
270-
best_score, worst_score = max(scores), min(scores)
292+
nplaces = min(5, np.floor(abs(math.log(max(max_score, 1e-10)))) + 2)
293+
score_span = (max_score - min_score) or 1
294+
self.bar_delegate.scale = (min_score, max_score)
295+
self.bar_delegate.float_fmt = "%%.%if" % int(nplaces)
271296

272-
best_run = scores.index(best_score)
273-
score_span = (best_score - worst_score) or 1
274-
max_score = max(scores)
275-
nplaces = min(5, np.floor(abs(math.log(max(max_score, 1e-10)))) + 2)
276-
fmt = "{{:.{}f}}".format(int(nplaces))
277297
model = self.table_model
278298
model.setRowCount(len(k_scores))
299+
no_selection = True
279300
for i, (k, score) in enumerate(k_scores):
280-
item = model.item(i, 0)
281-
if item is None:
282-
item = QStandardItem()
283-
item.setData(k, Qt.DisplayRole)
284-
item.setTextAlignment(Qt.AlignCenter)
285-
model.setItem(i, 0, item)
286-
item = model.item(i, 1)
287-
if item is None:
288-
item = QStandardItem()
289-
item.setData(fmt.format(score) if not np.isnan(score) else 'out-of-memory error',
290-
Qt.DisplayRole)
291-
bar_ratio = 0.95 * (score - worst_score) / score_span
292-
item.setData(bar_ratio, gui.TableBarItem.BarRole)
301+
item0 = model.item(i, 0) or QStandardItem()
302+
item0.setData(k, Qt.DisplayRole)
303+
item0.setTextAlignment(Qt.AlignCenter)
304+
model.setItem(i, 0, item0)
305+
item = model.item(i, 1) or QStandardItem()
306+
if not isinstance(score, str):
307+
item.setData(score, Qt.DisplayRole)
308+
item.setData(None, Qt.ToolTipRole)
309+
bar_ratio = 0.95 * (score - min_score) / score_span
310+
item.setData(bar_ratio, gui.BarRatioRole)
311+
if no_selection and score == best_score:
312+
self.table_view.selectRow(i)
313+
no_selection = False
314+
color = Qt.black
315+
flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
316+
else:
317+
item.setData("clustering failed", Qt.DisplayRole)
318+
item.setData(score, Qt.ToolTipRole)
319+
item.setData(None, gui.BarRatioRole)
320+
color = Qt.gray
321+
flags = Qt.NoItemFlags
322+
item0.setData(QBrush(color), Qt.ForegroundRole)
323+
item0.setFlags(flags)
324+
item.setData(QBrush(color), Qt.ForegroundRole)
325+
item.setFlags(flags)
293326
model.setItem(i, 1, item)
294327
self.table_view.resizeRowsToContents()
295328

296-
self.table_view.selectRow(best_run)
297329
self.table_view.show()
298330
if minimize:
299331
self.table_box.setTitle("Scoring (smaller is better)")
@@ -314,13 +346,12 @@ def selected_row(self):
314346
def table_item_selected(self):
315347
row = self.selected_row()
316348
if row is not None:
317-
self.send_data(row)
349+
self.send_data()
318350

319-
def send_data(self, row=None):
351+
def send_data(self):
320352
if self.optimize_k:
321-
if row is None:
322-
row = self.selected_row()
323-
km = self.optimization_runs[row][1]
353+
row = self.selected_row() if self.optimization_runs else None
354+
km = self.optimization_runs[row][1] if row is not None else None
324355
else:
325356
km = self.km
326357
if not self.data or not km:
@@ -356,6 +387,8 @@ def send_data(self, row=None):
356387
def set_data(self, data):
357388
self.data = data
358389
if data is None:
390+
self.Error.clear()
391+
self.Warning.clear()
359392
self.table_model.setRowCount(0)
360393
self.send("Annotated Data", None)
361394
self.send("Centroids", None)

Orange/widgets/unsupervised/tests/test_owkmeans.py

Lines changed: 61 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,17 @@
1+
from unittest.mock import patch
2+
13
from AnyQt.QtWidgets import QRadioButton
24

35
from Orange.widgets.tests.base import WidgetTest
46
from Orange.widgets.unsupervised.owkmeans import OWKMeans
7+
import Orange.clustering
58

69
from Orange.data import Table
710
class TestOWKMeans(WidgetTest):
811

912
def setUp(self):
10-
self.widget = self.create_widget(OWKMeans,
11-
stored_settings={"auto_apply": False})
13+
self.widget = self.create_widget(
14+
OWKMeans, stored_settings={"auto_apply": False}) # type: OWKMeans
1215
self.iris = Table("iris")
1316

1417
def test_optimization_report_display(self):
@@ -32,3 +35,59 @@ def test_data_on_output(self):
3235
self.send_signal("Data", None)
3336
# removing data should have cleared the output
3437
self.assertEqual(self.widget.data, None)
38+
39+
class KMeansFail(Orange.clustering.KMeans):
40+
fail_on = set()
41+
42+
def fit(self, *args):
43+
# when not optimizing, params is empty?!
44+
k = self.params.get("n_clusters", 3)
45+
if k in self.fail_on:
46+
raise ValueError("k={} fails".format(k))
47+
return super().fit(*args)
48+
49+
@patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail)
50+
def test_optimization_fails(self):
51+
widget = self.widget
52+
widget.k_from = 3
53+
widget.k_to = 8
54+
widget.scoring = 0
55+
widget.optimize_k = True
56+
57+
self.KMeansFail.fail_on = {3, 5, 7}
58+
self.send_signal("Data", self.iris)
59+
self.assertIsInstance(widget.optimization_runs[0][1], str)
60+
self.assertIsInstance(widget.optimization_runs[2][1], str)
61+
self.assertIsInstance(widget.optimization_runs[4][1], str)
62+
self.assertNotIsInstance(widget.optimization_runs[1][1], str)
63+
self.assertNotIsInstance(widget.optimization_runs[3][1], str)
64+
self.assertNotIsInstance(widget.optimization_runs[5][1], str)
65+
self.assertFalse(widget.Error.failed.is_shown())
66+
self.assertEqual(widget.selected_row(), 1)
67+
self.assertIsNotNone(self.get_output("Annotated Data"))
68+
69+
self.KMeansFail.fail_on = set(range(3, 9))
70+
widget.run()
71+
self.assertTrue(widget.Error.failed.is_shown())
72+
self.assertEqual(widget.optimization_runs, [])
73+
self.assertIsNone(self.get_output("Annotated Data"))
74+
75+
self.KMeansFail.fail_on = set()
76+
widget.run()
77+
self.assertFalse(widget.Error.failed.is_shown())
78+
self.assertEqual(widget.selected_row(), 0)
79+
self.assertIsNotNone(self.get_output("Annotated Data"))
80+
81+
@patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail)
82+
def test_run_fails(self):
83+
self.widget.k = 3
84+
self.widget.optimize_k = False
85+
self.KMeansFail.fail_on = {3}
86+
self.send_signal("Data", self.iris)
87+
self.assertTrue(self.widget.Error.failed.is_shown())
88+
self.assertIsNone(self.get_output("Annotated Data"))
89+
90+
self.KMeansFail.fail_on = set()
91+
self.widget.run()
92+
self.assertFalse(self.widget.Error.failed.is_shown())
93+
self.assertIsNotNone(self.get_output("Annotated Data"))

0 commit comments

Comments
 (0)