Skip to content

Commit 687de8f

Browse files
committed
Distances: Interruptible task
1 parent 6b2bb3c commit 687de8f

File tree

2 files changed

+57
-20
lines changed

2 files changed

+57
-20
lines changed

Orange/widgets/unsupervised/owdistances.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from AnyQt.QtCore import Qt
23
from scipy.sparse import issparse
34
import bottleneck as bn
@@ -34,11 +35,24 @@ def run(data: Orange.data.Table, metric: distance, normalized_dist: bool,
3435
return None
3536

3637
state.set_status("Calculating...")
38+
axis = 1 - axis
39+
kwargs = {"axis": axis, "impute": True}
3740
if metric.supports_normalization and normalized_dist:
38-
return metric(data, axis=1 - axis, impute=True,
39-
normalize=True)
40-
else:
41-
return metric(data, axis=1 - axis, impute=True)
41+
kwargs["normalize"] = True
42+
43+
if axis == 1: # rows
44+
n_data = len(data)
45+
dist = np.empty((n_data, n_data), dtype=float)
46+
for i in range(n_data):
47+
dist[:, i: i + 1] = metric(data, data[i: i + 1], **kwargs)
48+
state.set_progress_value(100 * i / n_data)
49+
if state.is_interruption_requested():
50+
return None
51+
np.fill_diagonal(dist, 0)
52+
dist = Orange.misc.DistMatrix(dist, row_items=data, axis=1)
53+
if axis == 0: # columns
54+
dist = metric(data, **kwargs)
55+
return dist
4256

4357

4458
class OWDistances(OWWidget, ConcurrentWidgetMixin):

Orange/widgets/unsupervised/tests/test_owdistances.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
import numpy as np
77

8-
from Orange.data import Table, Domain
98
from Orange import distance
9+
from Orange.data import Table, Domain
10+
from Orange.misc import DistMatrix
1011
from Orange.widgets.unsupervised.owdistances import OWDistances, METRICS, \
1112
DistanceRunner
1213
from Orange.widgets.tests.base import WidgetTest
@@ -17,28 +18,51 @@ class TestDistanceRunner(unittest.TestCase):
1718
def setUpClass(cls):
1819
super().setUpClass()
1920
cls.iris = Table("iris")[::5]
21+
cls.iris.X[0, 2] = np.nan
22+
cls.iris.X[1, 3] = np.nan
23+
cls.iris.X[2, 1] = np.nan
24+
cls.zoo = Table("zoo")[::5]
25+
cls.zoo.X[0, 2] = np.nan
26+
cls.zoo.X[1, 3] = np.nan
27+
cls.zoo.X[2, 1] = np.nan
2028

2129
def test_run(self):
22-
for _, metric in METRICS:
30+
state = Mock()
31+
state.is_interruption_requested = Mock(return_value=False)
32+
for name, metric in METRICS:
33+
data = self.iris
34+
if not metric.supports_missing:
35+
data = distance.impute(data)
36+
elif name == "Jaccard":
37+
data = self.zoo
38+
2339
# between rows, normalized
24-
dist1 = DistanceRunner.run(self.iris, metric, True, 0, Mock())
25-
dist2 = metric(self.iris, axis=1, impute=True, normalize=True)
26-
np.testing.assert_array_equal(dist1, dist2)
40+
dist1 = DistanceRunner.run(data, metric, True, 0, state)
41+
dist2 = metric(data, axis=1, impute=True, normalize=True)
42+
self.assertDistMatrixEqual(dist1, dist2)
2743

2844
# between rows, not normalized
29-
dist1 = DistanceRunner.run(self.iris, metric, False, 0, Mock())
30-
dist2 = metric(self.iris, axis=1, impute=True, normalize=False)
31-
np.testing.assert_array_equal(dist1, dist2)
45+
dist1 = DistanceRunner.run(data, metric, False, 0, state)
46+
dist2 = metric(data, axis=1, impute=True, normalize=False)
47+
self.assertDistMatrixEqual(dist1, dist2)
3248

3349
# between columns, normalized
34-
dist1 = DistanceRunner.run(self.iris, metric, True, 1, Mock())
35-
dist2 = metric(self.iris, axis=0, impute=True, normalize=True)
36-
np.testing.assert_array_equal(dist1, dist2)
50+
dist1 = DistanceRunner.run(data, metric, True, 1, state)
51+
dist2 = metric(data, axis=0, impute=True, normalize=True)
52+
self.assertDistMatrixEqual(dist1, dist2)
3753

3854
# between columns, not normalized
39-
dist1 = DistanceRunner.run(self.iris, metric, False, 1, Mock())
40-
dist2 = metric(self.iris, axis=0, impute=True, normalize=False)
41-
np.testing.assert_array_equal(dist1, dist2)
55+
dist1 = DistanceRunner.run(data, metric, False, 1, state)
56+
dist2 = metric(data, axis=0, impute=True, normalize=False)
57+
self.assertDistMatrixEqual(dist1, dist2)
58+
59+
def assertDistMatrixEqual(self, dist1, dist2):
60+
self.assertIsInstance(dist1, DistMatrix)
61+
self.assertIsInstance(dist2, DistMatrix)
62+
self.assertEqual(dist1.axis, dist2.axis)
63+
self.assertEqual(dist1.row_items, dist2.row_items)
64+
self.assertEqual(dist1.col_items, dist2.col_items)
65+
np.testing.assert_array_almost_equal(dist1, dist2)
4266

4367

4468
class TestOWDistances(WidgetTest):
@@ -59,14 +83,13 @@ def test_distance_combo(self):
5983
self.widget.metrics_combo.activated.emit(i)
6084
self.widget.metrics_combo.setCurrentIndex(i)
6185
self.wait_until_stop_blocking()
62-
self.send_signal(self.widget.Inputs.data, self.iris)
6386
if metric.supports_normalization:
6487
expected = metric(self.iris, normalize=self.widget.normalized_dist)
6588
else:
6689
expected = metric(self.iris)
6790

6891
if metric is not distance.Jaccard:
69-
np.testing.assert_array_equal(
92+
np.testing.assert_array_almost_equal(
7093
expected, self.get_output(self.widget.Outputs.distances))
7194

7295
def test_error_message(self):

0 commit comments

Comments
 (0)