Skip to content

Commit a7f3940

Browse files
committed
Distances: Interruptible task
1 parent f1e9e53 commit a7f3940

File tree

4 files changed

+98
-49
lines changed

4 files changed

+98
-49
lines changed

Orange/distance/base.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ class Distance:
116116
impute (bool):
117117
if `True` (default is `False`), nans in the computed distances
118118
are replaced with zeros, and infs with very large numbers.
119+
callback (callable or None):
120+
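optional callable stored as `self.callback` and passed on to the fitted model; the Distances widget supplies one that receives a progress value and returns `True` when interruption is requested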
119121
120122
Attributes:
121123
axis (int):
@@ -162,10 +164,12 @@ class Distance:
162164
axis = 1
163165
impute = False
164166

165-
def __new__(cls, e1=None, e2=None, axis=1, impute=False, **kwargs):
167+
def __new__(cls, e1=None, e2=None, axis=1, impute=False,
168+
callback=None, **kwargs):
166169
self = super().__new__(cls)
167170
self.axis = axis
168171
self.impute = impute
172+
self.callback = callback
169173
# Ugly, but needed to allow setting subclass-specific parameters
170174
# (such as normalize) when `e1` is not `None` and the `__new__` in the
171175
# subclass is skipped
@@ -225,11 +229,14 @@ class DistanceModel:
225229
impute (bool):
226230
if `True` (default is `False`), nans in the computed distances
227231
are replaced with zeros, and infs with very large numbers
232+
callback (callable or None):
233+
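optional callable stored as `self.callback`; the Distances widget supplies one that receives a progress value and returns `True` when interruption is requested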
228234
229235
"""
230-
def __init__(self, axis, impute=False):
236+
def __init__(self, axis, impute=False, callback=None):
231237
self._axis = axis
232238
self.impute = impute
239+
self.callback = callback
233240

234241
@property
235242
def axis(self):
@@ -291,9 +298,10 @@ class FittedDistanceModel(DistanceModel):
291298
continuous (np.ndarray): bool array indicating continuous attributes
292299
normalize (bool):
293300
if `True` (default is `False`) continuous columns are normalized
301+
callback (callable or None): optional callback, passed on to `DistanceModel.__init__`
294302
"""
295-
def __init__(self, attributes, axis=1, impute=False):
296-
super().__init__(axis, impute)
303+
def __init__(self, attributes, axis=1, impute=False, callback=None):
304+
super().__init__(axis, impute, callback)
297305
self.attributes = attributes
298306
self.discrete = None
299307
self.continuous = None
@@ -464,7 +472,8 @@ def fit_rows(self, attributes, x, n_vals):
464472
continuous, discrete,
465473
offsets[:curr_cont], scales[:curr_cont],
466474
dist_missing2_cont[:curr_cont],
467-
dist_missing_disc, dist_missing2_disc)
475+
dist_missing_disc, dist_missing2_disc,
476+
self.callback)
468477

469478
@staticmethod
470479
def get_discrete_stats(column, n_bins):

Orange/distance/distance.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class EuclideanRowsModel(FittedDistanceModel):
2626
def __init__(self, attributes, impute, normalize,
2727
continuous, discrete,
2828
means, stdvars, dist_missing2_cont,
29-
dist_missing_disc, dist_missing2_disc):
30-
super().__init__(attributes, 1, impute)
29+
dist_missing_disc, dist_missing2_disc, callback):
30+
super().__init__(attributes, 1, impute, callback)
3131
self.normalize = normalize
3232
self.continuous = continuous
3333
self.discrete = discrete
@@ -93,8 +93,9 @@ class EuclideanColumnsModel(FittedDistanceModel):
9393
Means are used as offsets for normalization, and two deviations are
9494
used for scaling.
9595
"""
96-
def __init__(self, attributes, impute, normalize, means, stdvars):
97-
super().__init__(attributes, 0, impute)
96+
def __init__(self, attributes, impute, normalize, means, stdvars,
97+
callback=None):
98+
super().__init__(attributes, 0, impute, callback)
9899
self.normalize = normalize
99100
self.means = means
100101
self.vars = stdvars
@@ -135,9 +136,11 @@ class Euclidean(FittedDistance):
135136
fallback = SklDistance('euclidean')
136137
rows_model_type = EuclideanRowsModel
137138

138-
def __new__(cls, e1=None, e2=None, axis=1, impute=False, normalize=False):
139+
def __new__(cls, e1=None, e2=None, axis=1, impute=False, normalize=False,
140+
callback=None):
139141
# pylint: disable=arguments-differ
140-
return super().__new__(cls, e1, e2, axis, impute, normalize=normalize)
142+
return super().__new__(cls, e1, e2, axis, impute, callback,
143+
normalize=normalize)
141144

142145
def get_continuous_stats(self, column):
143146
"""
@@ -180,7 +183,8 @@ def nowarn(msg, cat, *args, **kwargs):
180183
if self.normalize and not stdvars.all():
181184
raise ValueError("some columns are constant")
182185
return EuclideanColumnsModel(
183-
attributes, self.impute, self.normalize, means, stdvars)
186+
attributes, self.impute, self.normalize, means, stdvars,
187+
self.callback)
184188

185189

186190
class ManhattanRowsModel(FittedDistanceModel):
@@ -193,8 +197,8 @@ class ManhattanRowsModel(FittedDistanceModel):
193197
def __init__(self, attributes, impute, normalize,
194198
continuous, discrete,
195199
medians, mads, dist_missing2_cont,
196-
dist_missing_disc, dist_missing2_disc):
197-
super().__init__(attributes, 1, impute)
200+
dist_missing_disc, dist_missing2_disc, callback=None):
201+
super().__init__(attributes, 1, impute, callback)
198202
self.normalize = normalize
199203
self.continuous = continuous
200204
self.discrete = discrete
@@ -250,8 +254,9 @@ class ManhattanColumnsModel(FittedDistanceModel):
250254
used for scaling.
251255
"""
252256

253-
def __init__(self, attributes, impute, normalize, medians, mads):
254-
super().__init__(attributes, 0, impute)
257+
def __init__(self, attributes, impute, normalize, medians, mads,
258+
callback=None):
259+
super().__init__(attributes, 0, impute, callback)
255260
self.normalize = normalize
256261
self.medians = medians
257262
self.mads = mads
@@ -271,9 +276,11 @@ class Manhattan(FittedDistance):
271276
fallback = SklDistance('manhattan')
272277
rows_model_type = ManhattanRowsModel
273278

274-
def __new__(cls, e1=None, e2=None, axis=1, impute=False, normalize=False):
279+
def __new__(cls, e1=None, e2=None, axis=1, impute=False, normalize=False,
280+
callback=None):
275281
# pylint: disable=arguments-differ
276-
return super().__new__(cls, e1, e2, axis, impute, normalize=normalize)
282+
return super().__new__(cls, e1, e2, axis, impute, callback,
283+
normalize=normalize)
277284

278285
def get_continuous_stats(self, column):
279286
"""
@@ -310,7 +317,8 @@ def fit_cols(self, attributes, x, n_vals):
310317
"some columns have zero absolute distance from median, "
311318
"or no values")
312319
return ManhattanColumnsModel(
313-
attributes, self.impute, self.normalize, medians, mads)
320+
attributes, self.impute, self.normalize, medians, mads,
321+
self.callback)
314322

315323

316324
class Cosine(FittedDistance):

Orange/widgets/unsupervised/owdistances.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import Orange.data
66
import Orange.misc
77
from Orange import distance
8-
from Orange.widgets import gui, settings
8+
from Orange.widgets import gui
9+
from Orange.widgets.settings import Setting
910
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
1011
from Orange.widgets.utils.sql import check_sql_input
1112
from Orange.widgets.utils.widgetpreview import WidgetPreview
@@ -34,12 +35,15 @@ def run(data: Orange.data.Table, metric: distance, normalized_dist: bool,
3435
if data is None:
3536
return None
3637

38+
def callback(i: float) -> bool: # return True if interrupt requested
39+
state.set_progress_value(i)
40+
return state.is_interruption_requested()
41+
3742
state.set_status("Calculating...")
43+
kwargs = {"axis": 1 - axis, "impute": True, "callback": callback}
3844
if metric.supports_normalization and normalized_dist:
39-
return metric(data, axis=1 - axis, impute=True,
40-
normalize=True)
41-
else:
42-
return metric(data, axis=1 - axis, impute=True)
45+
kwargs["normalize"] = True
46+
return metric(data, **kwargs)
4347

4448

4549
class OWDistances(OWWidget, ConcurrentWidgetMixin):
@@ -56,14 +60,14 @@ class Outputs:
5660

5761
settings_version = 2
5862

59-
axis = settings.Setting(0) # type: int
60-
metric_idx = settings.Setting(0) # type: int
63+
axis = Setting(0) # type: int
64+
metric_idx = Setting(0) # type: int
6165

6266
#: Use normalized distances if the metric supports it.
6367
#: The default is `True`, except when restoring from old pre v2 settings
6468
#: (see `migrate_settings`).
65-
normalized_dist = settings.Setting(True) # type: bool
66-
autocommit = settings.Setting(True) # type: bool
69+
normalized_dist = Setting(True) # type: bool
70+
autocommit = Setting(True) # type: bool
6771

6872
want_main_area = False
6973
buttons_area_orientation = Qt.Vertical
@@ -187,13 +191,13 @@ def on_done(self, dist: Orange.misc.DistMatrix):
187191
assert isinstance(dist, Orange.misc.DistMatrix) or dist is None
188192
self.Outputs.distances.send(dist)
189193

190-
def on_exception(self, e):
191-
if isinstance(e, ValueError):
192-
self.Error.distances_value_error(e)
193-
elif isinstance(e, MemoryError):
194+
def on_exception(self, ex):
195+
if isinstance(ex, ValueError):
196+
self.Error.distances_value_error(ex)
197+
elif isinstance(ex, MemoryError):
194198
self.Error.distances_memory_error()
195199
else:
196-
raise e
200+
raise ex
197201

198202
def onDeleteWidget(self):
199203
self.shutdown()

Orange/widgets/unsupervised/tests/test_owdistances.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
import numpy as np
77

8-
from Orange.data import Table, Domain
98
from Orange import distance
9+
from Orange.data import Table, Domain
10+
from Orange.misc import DistMatrix
1011
from Orange.widgets.unsupervised.owdistances import OWDistances, METRICS, \
1112
DistanceRunner
1213
from Orange.widgets.tests.base import WidgetTest
@@ -17,28 +18,51 @@ class TestDistanceRunner(unittest.TestCase):
1718
def setUpClass(cls):
1819
super().setUpClass()
1920
cls.iris = Table("iris")[::5]
21+
cls.iris.X[0, 2] = np.nan
22+
cls.iris.X[1, 3] = np.nan
23+
cls.iris.X[2, 1] = np.nan
24+
cls.zoo = Table("zoo")[::5]
25+
cls.zoo.X[0, 2] = np.nan
26+
cls.zoo.X[1, 3] = np.nan
27+
cls.zoo.X[2, 1] = np.nan
2028

2129
def test_run(self):
22-
for _, metric in METRICS:
30+
state = Mock()
31+
state.is_interruption_requested = Mock(return_value=False)
32+
for name, metric in METRICS:
33+
data = self.iris
34+
if not metric.supports_missing:
35+
data = distance.impute(data)
36+
elif name == "Jaccard":
37+
data = self.zoo
38+
2339
# between rows, normalized
24-
dist1 = DistanceRunner.run(self.iris, metric, True, 0, Mock())
25-
dist2 = metric(self.iris, axis=1, impute=True, normalize=True)
26-
np.testing.assert_array_equal(dist1, dist2)
40+
dist1 = DistanceRunner.run(data, metric, True, 0, state)
41+
dist2 = metric(data, axis=1, impute=True, normalize=True)
42+
self.assertDistMatrixEqual(dist1, dist2)
2743

2844
# between rows, not normalized
29-
dist1 = DistanceRunner.run(self.iris, metric, False, 0, Mock())
30-
dist2 = metric(self.iris, axis=1, impute=True, normalize=False)
31-
np.testing.assert_array_equal(dist1, dist2)
45+
dist1 = DistanceRunner.run(data, metric, False, 0, state)
46+
dist2 = metric(data, axis=1, impute=True, normalize=False)
47+
self.assertDistMatrixEqual(dist1, dist2)
3248

3349
# between columns, normalized
34-
dist1 = DistanceRunner.run(self.iris, metric, True, 1, Mock())
35-
dist2 = metric(self.iris, axis=0, impute=True, normalize=True)
36-
np.testing.assert_array_equal(dist1, dist2)
50+
dist1 = DistanceRunner.run(data, metric, True, 1, state)
51+
dist2 = metric(data, axis=0, impute=True, normalize=True)
52+
self.assertDistMatrixEqual(dist1, dist2)
3753

3854
# between columns, not normalized
39-
dist1 = DistanceRunner.run(self.iris, metric, False, 1, Mock())
40-
dist2 = metric(self.iris, axis=0, impute=True, normalize=False)
41-
np.testing.assert_array_equal(dist1, dist2)
55+
dist1 = DistanceRunner.run(data, metric, False, 1, state)
56+
dist2 = metric(data, axis=0, impute=True, normalize=False)
57+
self.assertDistMatrixEqual(dist1, dist2)
58+
59+
def assertDistMatrixEqual(self, dist1, dist2):
60+
self.assertIsInstance(dist1, DistMatrix)
61+
self.assertIsInstance(dist2, DistMatrix)
62+
self.assertEqual(dist1.axis, dist2.axis)
63+
self.assertEqual(dist1.row_items, dist2.row_items)
64+
self.assertEqual(dist1.col_items, dist2.col_items)
65+
np.testing.assert_array_almost_equal(dist1, dist2)
4266

4367

4468
class TestOWDistances(WidgetTest):
@@ -59,14 +83,13 @@ def test_distance_combo(self):
5983
self.widget.metrics_combo.activated.emit(i)
6084
self.widget.metrics_combo.setCurrentIndex(i)
6185
self.wait_until_stop_blocking()
62-
self.send_signal(self.widget.Inputs.data, self.iris)
6386
if metric.supports_normalization:
6487
expected = metric(self.iris, normalize=self.widget.normalized_dist)
6588
else:
6689
expected = metric(self.iris)
6790

6891
if metric is not distance.Jaccard:
69-
np.testing.assert_array_equal(
92+
np.testing.assert_array_almost_equal(
7093
expected, self.get_output(self.widget.Outputs.distances))
7194

7295
def test_error_message(self):
@@ -155,3 +178,8 @@ def test_negative_values_bhattacharyya(self):
155178
self.send_signal(self.widget.Inputs.data, self.iris)
156179
self.assertTrue(self.widget.Error.distances_value_error.is_shown())
157180
self.iris.X[0, 0] *= -1
181+
182+
183+
if __name__ == "__main__":
184+
unittest.main()
185+

0 commit comments

Comments
 (0)