Skip to content

Commit 9e0d4eb

Browse files
authored
Merge pull request #3669 from VesnaT/vizrank_threads
[FIX] VizRankDialog: Use extended thread pool to prevent segfaults
2 parents 2125acd + 5e14813 commit 9e0d4eb

File tree

5 files changed

+251
-89
lines changed

5 files changed

+251
-89
lines changed

Orange/widgets/data/owcorrelations.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class KMeansCorrelationHeuristic:
5959
def __init__(self, data):
6060
self.n_attributes = len(data.domain.attributes)
6161
self.data = data
62-
self.states = None
62+
self.clusters = None
6363
self.n_clusters = int(np.sqrt(self.n_attributes))
6464

6565
def get_clusters_of_attributes(self):
@@ -84,16 +84,15 @@ def get_states(self, initial_state):
8484
:param initial_state: initial state; None if this is the first call
8585
:return: generator of tuples of states
8686
"""
87-
if self.states is not None:
88-
return chain([initial_state], self.states)
89-
90-
clusters = self.get_clusters_of_attributes()
87+
if self.clusters is None:
88+
self.clusters = self.get_clusters_of_attributes()
89+
clusters = self.clusters
9190

9291
# combinations within clusters
93-
self.states = chain.from_iterable(combinations(cluster.instances, 2)
94-
for cluster in clusters)
92+
states0 = chain.from_iterable(combinations(cluster.instances, 2)
93+
for cluster in clusters)
9594
if self.n_clusters == 1:
96-
return self.states
95+
return states0
9796

9897
# combinations among clusters - closest clusters first
9998
centroids = [c.centroid for c in clusters]
@@ -104,8 +103,13 @@ def get_states(self, initial_state):
104103
states = ((min((c1, c2)), max((c1, c2))) for i in np.argsort(distances)
105104
for c1 in clusters[cluster_combs[i][0]].instances
106105
for c2 in clusters[cluster_combs[i][1]].instances)
107-
self.states = chain(self.states, states)
108-
return self.states
106+
states = chain(states0, states)
107+
108+
if initial_state is not None:
109+
while next(states) != initial_state:
110+
pass
111+
return chain([initial_state], states)
112+
return states
109113

110114

111115
class CorrelationRank(VizRankDialogAttrPair):

Orange/widgets/data/tests/test_owcorrelations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,14 @@ def test_row_for_state(self):
308308
self.assertEqual(row[1].data(Qt.DisplayRole), self.attrs[0].name)
309309
self.assertEqual(row[2].data(Qt.DisplayRole), self.attrs[1].name)
310310

311+
def test_iterate_states(self):
312+
self.assertListEqual(list(self.vizrank.iterate_states(None)),
313+
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
314+
self.assertListEqual(list(self.vizrank.iterate_states((1, 0))),
315+
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
316+
self.assertListEqual(list(self.vizrank.iterate_states((2, 1))),
317+
[(2, 1), (3, 0), (3, 1), (3, 2)])
318+
311319
def test_iterate_states_by_feature(self):
312320
self.vizrank.sel_feature_index = 2
313321
states = self.vizrank.iterate_states_by_feature()
@@ -345,3 +353,7 @@ def test_get_states_one_cluster(self):
345353
states = set(heuristic.get_states(None))
346354
self.assertEqual(len(states), 1)
347355
self.assertSetEqual(states, {(0, 1)})
356+
357+
358+
if __name__ == "__main__":
359+
unittest.main()

Orange/widgets/visualize/tests/test_owlinearprojection.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from Orange.widgets.visualize.owlinearprojection import (
1616
OWLinearProjection, LinearProjectionVizRank
1717
)
18-
from Orange.widgets.visualize.utils import Worker
18+
from Orange.widgets.visualize.utils import run_vizrank
1919

2020

2121
class TestOWLinearProjection(WidgetTest, AnchorProjectionWidgetTestMixin,
@@ -205,16 +205,14 @@ def setUp(self):
205205

206206
def test_discrete_class(self):
207207
self.send_signal(self.widget.Inputs.data, self.data)
208-
worker = Worker(self.vizrank)
209-
self.vizrank.keep_running = True
210-
worker.do_work()
208+
run_vizrank(self.vizrank.compute_score,
209+
self.vizrank.iterate_states(None), [], Mock())
211210

212211
def test_continuous_class(self):
213212
data = Table("housing")[::100]
214213
self.send_signal(self.widget.Inputs.data, data)
215-
worker = Worker(self.vizrank)
216-
self.vizrank.keep_running = True
217-
worker.do_work()
214+
run_vizrank(self.vizrank.compute_score,
215+
self.vizrank.iterate_states(None), [], Mock())
218216

219217
def test_set_attrs(self):
220218
self.send_signal(self.widget.Inputs.data, self.data)
@@ -230,3 +228,8 @@ def test_set_attrs(self):
230228
self.assertNotEqual(self.widget.model_selected[:], model_selected)
231229
c2 = self.get_output(self.widget.Outputs.components)
232230
self.assertNotEqual(c1.domain.attributes, c2.domain.attributes)
231+
232+
233+
if __name__ == "__main__":
234+
import unittest
235+
unittest.main()
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from itertools import chain
2+
import unittest
3+
from unittest.mock import Mock
4+
from queue import Queue
5+
6+
from AnyQt.QtGui import QStandardItem
7+
8+
from Orange.data import Table
9+
from Orange.widgets.visualize.utils import (
10+
VizRankDialog, Result, run_vizrank, QueuedScore
11+
)
12+
from Orange.widgets.tests.base import WidgetTest
13+
14+
15+
def compute_score(x):
16+
return (x[0] + 1) / (x[1] + 1)
17+
18+
19+
class TestRunner(unittest.TestCase):
20+
@classmethod
21+
def setUpClass(cls):
22+
cls.data = Table("iris")
23+
24+
def test_Result(self):
25+
res = Result(queue=Queue(), scores=[])
26+
self.assertIsInstance(res.queue, Queue)
27+
self.assertIsInstance(res.scores, list)
28+
29+
def test_run_vizrank(self):
30+
scores, task = [], Mock()
31+
# run through all states
32+
task.is_interruption_requested.return_value = False
33+
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
34+
res = run_vizrank(compute_score, chain(states), scores, task)
35+
36+
next_state = self.assertQueueEqual(
37+
res.queue, [0, 0, 0, 3, 2, 5], compute_score,
38+
states, states[1:] + [None])
39+
self.assertIsNone(next_state)
40+
res_scores = sorted([compute_score(x) for x in states])
41+
self.assertListEqual(res.scores, res_scores)
42+
self.assertIsNot(scores, res.scores)
43+
self.assertEqual(task.set_partial_result.call_count, 6)
44+
45+
def test_run_vizrank_interrupt(self):
46+
scores, task = [], Mock()
47+
# interrupt calculation in third iteration
48+
task.is_interruption_requested.side_effect = lambda: \
49+
True if task.is_interruption_requested.call_count > 2 else False
50+
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
51+
res = run_vizrank(compute_score, chain(states), scores, task)
52+
53+
next_state = self.assertQueueEqual(
54+
res.queue, [0, 0], compute_score, states[:2], states[1:3])
55+
self.assertEqual(next_state, (0, 3))
56+
res_scores = sorted([compute_score(x) for x in states[:2]])
57+
self.assertListEqual(res.scores, res_scores)
58+
self.assertIsNot(scores, res.scores)
59+
self.assertEqual(task.set_partial_result.call_count, 2)
60+
61+
# continue calculation through all states
62+
task.is_interruption_requested.side_effect = lambda: False
63+
i = states.index(next_state)
64+
res = run_vizrank(compute_score, chain(states[i:]), res_scores, task)
65+
66+
next_state = self.assertQueueEqual(
67+
res.queue, [0, 3, 2, 5], compute_score, states[2:],
68+
states[3:] + [None])
69+
self.assertIsNone(next_state)
70+
res_scores = sorted([compute_score(x) for x in states])
71+
self.assertListEqual(res.scores, res_scores)
72+
self.assertIsNot(scores, res.scores)
73+
self.assertEqual(task.set_partial_result.call_count, 6)
74+
75+
def assertQueueEqual(self, queue, positions, f, states, next_states):
76+
self.assertIsInstance(queue, Queue)
77+
for qs in (QueuedScore(position=p, score=f(s), state=s, next_state=ns)
78+
for p, s, ns in zip(positions, states, next_states)):
79+
result = queue.get_nowait()
80+
self.assertEqual(result.position, qs.position)
81+
self.assertEqual(result.state, qs.state)
82+
self.assertEqual(result.next_state, qs.next_state)
83+
self.assertEqual(result.score, qs.score)
84+
next_state = result.next_state
85+
return next_state
86+
87+
88+
class TestVizRankDialog(WidgetTest):
89+
def test_on_partial_result(self):
90+
def iterate_states(initial_state):
91+
if initial_state is not None:
92+
return chain(states[states.index(initial_state):])
93+
return chain(states)
94+
95+
def invoke_on_partial_result():
96+
widget.on_partial_result(run_vizrank(
97+
widget.compute_score,
98+
widget.iterate_states(widget.saved_state),
99+
widget.scores, task
100+
))
101+
102+
task = Mock()
103+
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
104+
105+
widget = VizRankDialog(None)
106+
widget.progressBarInit()
107+
widget.compute_score = compute_score
108+
widget.iterate_states = iterate_states
109+
widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))]
110+
111+
# interrupt calculation in third iteration
112+
task.is_interruption_requested.side_effect = lambda: \
113+
True if task.is_interruption_requested.call_count > 2 else False
114+
invoke_on_partial_result()
115+
self.assertEqual(widget.rank_model.rowCount(), 2)
116+
for row, score in enumerate(
117+
sorted([compute_score(x) for x in states[:2]])):
118+
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
119+
self.assertEqual(widget.saved_progress, 2)
120+
121+
# continue calculation through all states
122+
task.is_interruption_requested.side_effect = lambda: False
123+
invoke_on_partial_result()
124+
self.assertEqual(widget.rank_model.rowCount(), 6)
125+
for row, score in enumerate(
126+
sorted([compute_score(x) for x in states])):
127+
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
128+
self.assertEqual(widget.saved_progress, 6)
129+
130+
131+
if __name__ == "__main__":
132+
unittest.main()

0 commit comments

Comments
 (0)