Skip to content

Commit f1e9e53

Browse files
committed
Distances: Offload work to a separate thread
1 parent 774c4be commit f1e9e53

File tree

2 files changed

+95
-22
lines changed

2 files changed

+95
-22
lines changed

Orange/widgets/unsupervised/owdistances.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import Orange.misc
77
from Orange import distance
88
from Orange.widgets import gui, settings
9+
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
910
from Orange.widgets.utils.sql import check_sql_input
1011
from Orange.widgets.utils.widgetpreview import WidgetPreview
1112
from Orange.widgets.widget import OWWidget, Msg, Input, Output
@@ -26,7 +27,22 @@
2627
]
2728

2829

29-
class OWDistances(OWWidget):
30+
class DistanceRunner:
31+
@staticmethod
32+
def run(data: Orange.data.Table, metric: distance, normalized_dist: bool,
33+
axis: int, state: TaskState) -> Orange.misc.DistMatrix:
34+
if data is None:
35+
return None
36+
37+
state.set_status("Calculating...")
38+
if metric.supports_normalization and normalized_dist:
39+
return metric(data, axis=1 - axis, impute=True,
40+
normalize=True)
41+
else:
42+
return metric(data, axis=1 - axis, impute=True)
43+
44+
45+
class OWDistances(OWWidget, ConcurrentWidgetMixin):
3046
name = "Distances"
3147
description = "Compute a matrix of pairwise distances."
3248
icon = "icons/Distance.svg"
@@ -65,13 +81,15 @@ class Warning(OWWidget.Warning):
6581
imputing_data = Msg("Missing values were imputed")
6682

6783
def __init__(self):
68-
super().__init__()
84+
OWWidget.__init__(self)
85+
ConcurrentWidgetMixin.__init__(self)
6986

7087
self.data = None
7188

72-
gui.radioButtons(self.controlArea, self, "axis", ["Rows", "Columns"],
73-
box="Distances between", callback=self._invalidate
74-
)
89+
gui.radioButtons(
90+
self.controlArea, self, "axis", ["Rows", "Columns"],
91+
box="Distances between", callback=self._invalidate
92+
)
7593
box = gui.widgetBox(self.controlArea, "Distance Metric")
7694
self.metrics_combo = gui.comboBox(
7795
box, self, "metric_idx",
@@ -93,6 +111,7 @@ def __init__(self):
93111
@Inputs.data
94112
@check_sql_input
95113
def set_data(self, data):
114+
self.cancel()
96115
self.data = data
97116
self.refresh_metrics()
98117
self.unconditional_commit()
@@ -106,8 +125,7 @@ def refresh_metrics(self):
106125
def commit(self):
107126
# pylint: disable=invalid-sequence-index
108127
metric = METRICS[self.metric_idx][1]
109-
dist = self.compute_distances(metric, self.data)
110-
self.Outputs.distances.send(dist)
128+
self.compute_distances(metric, self.data)
111129

112130
def compute_distances(self, metric, data):
113131
def _check_sparse():
@@ -152,22 +170,34 @@ def _fix_missing():
152170
return True
153171

154172
self.clear_messages()
155-
if data is None:
156-
return None
157-
for check in (_check_sparse,
158-
_fix_discrete, _fix_missing, _fix_nonbinary):
159-
if not check():
160-
return None
161-
try:
162-
if metric.supports_normalization and self.normalized_dist:
163-
return metric(data, axis=1 - self.axis, impute=True,
164-
normalize=True)
165-
else:
166-
return metric(data, axis=1 - self.axis, impute=True)
167-
except ValueError as e:
173+
if data is not None:
174+
for check in (_check_sparse, _fix_discrete,
175+
_fix_missing, _fix_nonbinary):
176+
if not check():
177+
data = None
178+
break
179+
180+
self.start(DistanceRunner.run, data, metric,
181+
self.normalized_dist, self.axis)
182+
183+
def on_partial_result(self, _):
184+
pass
185+
186+
def on_done(self, dist: Orange.misc.DistMatrix):
187+
assert isinstance(dist, Orange.misc.DistMatrix) or dist is None
188+
self.Outputs.distances.send(dist)
189+
190+
def on_exception(self, e):
191+
if isinstance(e, ValueError):
168192
self.Error.distances_value_error(e)
169-
except MemoryError:
193+
elif isinstance(e, MemoryError):
170194
self.Error.distances_memory_error()
195+
else:
196+
raise e
197+
198+
def onDeleteWidget(self):
199+
self.shutdown()
200+
super().onDeleteWidget()
171201

172202
def _invalidate(self):
173203
self.commit()

Orange/widgets/unsupervised/tests/test_owdistances.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,46 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
import unittest
34
from unittest.mock import Mock
45

56
import numpy as np
67

78
from Orange.data import Table, Domain
89
from Orange import distance
9-
from Orange.widgets.unsupervised.owdistances import OWDistances, METRICS
10+
from Orange.widgets.unsupervised.owdistances import OWDistances, METRICS, \
11+
DistanceRunner
1012
from Orange.widgets.tests.base import WidgetTest
1113

1214

15+
class TestDistanceRunner(unittest.TestCase):
16+
@classmethod
17+
def setUpClass(cls):
18+
super().setUpClass()
19+
cls.iris = Table("iris")[::5]
20+
21+
def test_run(self):
22+
for _, metric in METRICS:
23+
# between rows, normalized
24+
dist1 = DistanceRunner.run(self.iris, metric, True, 0, Mock())
25+
dist2 = metric(self.iris, axis=1, impute=True, normalize=True)
26+
np.testing.assert_array_equal(dist1, dist2)
27+
28+
# between rows, not normalized
29+
dist1 = DistanceRunner.run(self.iris, metric, False, 0, Mock())
30+
dist2 = metric(self.iris, axis=1, impute=True, normalize=False)
31+
np.testing.assert_array_equal(dist1, dist2)
32+
33+
# between columns, normalized
34+
dist1 = DistanceRunner.run(self.iris, metric, True, 1, Mock())
35+
dist2 = metric(self.iris, axis=0, impute=True, normalize=True)
36+
np.testing.assert_array_equal(dist1, dist2)
37+
38+
# between columns, not normalized
39+
dist1 = DistanceRunner.run(self.iris, metric, False, 1, Mock())
40+
dist2 = metric(self.iris, axis=0, impute=True, normalize=False)
41+
np.testing.assert_array_equal(dist1, dist2)
42+
43+
1344
class TestOWDistances(WidgetTest):
1445
@classmethod
1546
def setUpClass(cls):
@@ -27,6 +58,7 @@ def test_distance_combo(self):
2758
for i, (_, metric) in enumerate(METRICS):
2859
self.widget.metrics_combo.activated.emit(i)
2960
self.widget.metrics_combo.setCurrentIndex(i)
61+
self.wait_until_stop_blocking()
3062
self.send_signal(self.widget.Inputs.data, self.iris)
3163
if metric.supports_normalization:
3264
expected = metric(self.iris, normalize=self.widget.normalized_dist)
@@ -42,8 +74,10 @@ def test_error_message(self):
4274
data is removed from input"""
4375
self.widget.metric_idx = 2
4476
self.send_signal(self.widget.Inputs.data, self.iris)
77+
self.wait_until_stop_blocking()
4578
self.assertFalse(self.widget.Error.no_continuous_features.is_shown())
4679
self.send_signal(self.widget.Inputs.data, self.titanic)
80+
self.wait_until_stop_blocking()
4781
self.assertTrue(self.widget.Error.no_continuous_features.is_shown())
4882
self.send_signal(self.widget.Inputs.data, None)
4983
self.assertFalse(self.widget.Error.no_continuous_features.is_shown())
@@ -53,32 +87,39 @@ def test_jaccard_messages(self):
5387
if name == "Jaccard":
5488
break
5589
self.send_signal(self.widget.Inputs.data, self.iris)
90+
self.wait_until_stop_blocking()
5691
self.assertTrue(self.widget.Error.no_binary_features.is_shown())
5792
self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())
5893

5994
self.send_signal(self.widget.Inputs.data, None)
95+
self.wait_until_stop_blocking()
6096
self.assertFalse(self.widget.Error.no_binary_features.is_shown())
6197
self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())
6298

6399
self.send_signal(self.widget.Inputs.data, self.titanic)
100+
self.wait_until_stop_blocking()
64101
self.assertFalse(self.widget.Error.no_binary_features.is_shown())
65102
self.assertTrue(self.widget.Warning.ignoring_nonbinary.is_shown())
66103

67104
self.send_signal(self.widget.Inputs.data, None)
105+
self.wait_until_stop_blocking()
68106
self.assertFalse(self.widget.Error.no_binary_features.is_shown())
69107
self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())
70108

71109
self.send_signal(self.widget.Inputs.data, self.titanic)
110+
self.wait_until_stop_blocking()
72111
self.assertFalse(self.widget.Error.no_binary_features.is_shown())
73112
self.assertTrue(self.widget.Warning.ignoring_nonbinary.is_shown())
74113

75114
dom = self.titanic.domain
76115
dom = Domain(dom.attributes[1:], dom.class_var)
77116
self.send_signal(self.widget.Inputs.data, self.titanic.transform(dom))
117+
self.wait_until_stop_blocking()
78118
self.assertFalse(self.widget.Error.no_binary_features.is_shown())
79119
self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())
80120

81121
self.send_signal(self.widget.Inputs.data, Table("heart_disease"))
122+
self.wait_until_stop_blocking()
82123
self.assertFalse(self.widget.Error.no_binary_features.is_shown())
83124
self.assertFalse(self.widget.Warning.ignoring_discrete.is_shown())
84125

@@ -93,10 +134,12 @@ def test_too_big_array(self):
93134

94135
mock = Mock(side_effect=ValueError)
95136
self.widget.compute_distances(mock, self.iris)
137+
self.wait_until_stop_blocking()
96138
self.assertTrue(self.widget.Error.distances_value_error.is_shown())
97139

98140
mock = Mock(side_effect=MemoryError)
99141
self.widget.compute_distances(mock, self.iris)
142+
self.wait_until_stop_blocking()
100143
self.assertEqual(len(self.widget.Error.active), 1)
101144
self.assertTrue(self.widget.Error.distances_memory_error.is_shown())
102145

0 commit comments

Comments
 (0)