Skip to content

Commit de3c95c

Browse files
committed
VizRankDialog: Use extended thread pool to prevent segfaults
When fed large datasets, correlations widget exited with segmentation fault, (probably) due to insufficient stack size for created task.
1 parent e19d857 commit de3c95c

File tree

4 files changed

+240
-79
lines changed

4 files changed

+240
-79
lines changed

Orange/widgets/data/tests/test_owcorrelations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,14 @@ def test_row_for_state(self):
308308
self.assertEqual(row[1].data(Qt.DisplayRole), self.attrs[0].name)
309309
self.assertEqual(row[2].data(Qt.DisplayRole), self.attrs[1].name)
310310

311+
def test_iterate_states(self):
312+
self.assertListEqual(list(self.vizrank.iterate_states(None)),
313+
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
314+
self.assertListEqual(list(self.vizrank.iterate_states((1, 0))),
315+
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
316+
self.assertListEqual(list(self.vizrank.iterate_states((2, 1))),
317+
[(2, 1), (3, 0), (3, 1), (3, 2)])
318+
311319
def test_iterate_states_by_feature(self):
312320
self.vizrank.sel_feature_index = 2
313321
states = self.vizrank.iterate_states_by_feature()
@@ -345,3 +353,7 @@ def test_get_states_one_cluster(self):
345353
states = set(heuristic.get_states(None))
346354
self.assertEqual(len(states), 1)
347355
self.assertSetEqual(states, {(0, 1)})
356+
357+
358+
if __name__ == "__main__":
359+
unittest.main()

Orange/widgets/visualize/tests/test_owlinearprojection.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from Orange.widgets.visualize.owlinearprojection import (
1616
OWLinearProjection, LinearProjectionVizRank
1717
)
18-
from Orange.widgets.visualize.utils import Worker
18+
from Orange.widgets.visualize.utils import run_vizrank
1919

2020

2121
class TestOWLinearProjection(WidgetTest, AnchorProjectionWidgetTestMixin,
@@ -205,16 +205,14 @@ def setUp(self):
205205

206206
def test_discrete_class(self):
207207
self.send_signal(self.widget.Inputs.data, self.data)
208-
worker = Worker(self.vizrank)
209-
self.vizrank.keep_running = True
210-
worker.do_work()
208+
run_vizrank(self.vizrank.compute_score,
209+
self.vizrank.iterate_states(None), [], Mock())
211210

212211
def test_continuous_class(self):
213212
data = Table("housing")[::100]
214213
self.send_signal(self.widget.Inputs.data, data)
215-
worker = Worker(self.vizrank)
216-
self.vizrank.keep_running = True
217-
worker.do_work()
214+
run_vizrank(self.vizrank.compute_score,
215+
self.vizrank.iterate_states(None), [], Mock())
218216

219217
def test_set_attrs(self):
220218
self.send_signal(self.widget.Inputs.data, self.data)
@@ -230,3 +228,8 @@ def test_set_attrs(self):
230228
self.assertNotEqual(self.widget.model_selected[:], model_selected)
231229
c2 = self.get_output(self.widget.Outputs.components)
232230
self.assertNotEqual(c1.domain.attributes, c2.domain.attributes)
231+
232+
233+
if __name__ == "__main__":
234+
import unittest
235+
unittest.main()
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from itertools import chain
2+
import unittest
3+
from unittest.mock import Mock
4+
from queue import Queue
5+
6+
from AnyQt.QtGui import QStandardItem
7+
8+
from Orange.data import Table
9+
from Orange.widgets.visualize.utils import (
10+
VizRankDialog, Result, run_vizrank, QueuedScore
11+
)
12+
from Orange.widgets.tests.base import WidgetTest
13+
14+
15+
def compute_score(x):
16+
return (x[0] + 1) / (x[1] + 1)
17+
18+
19+
class TestRunner(unittest.TestCase):
20+
@classmethod
21+
def setUpClass(cls):
22+
cls.data = Table("iris")
23+
24+
def test_Result(self):
25+
res = Result(queue=Queue(), scores=[])
26+
self.assertIsInstance(res.queue, Queue)
27+
self.assertIsInstance(res.scores, list)
28+
29+
def test_run_vizrank(self):
30+
scores, task = [], Mock()
31+
# run through all states
32+
task.is_interruption_requested.return_value = False
33+
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
34+
res = run_vizrank(compute_score, chain(states), scores, task)
35+
36+
next_state = self.assertQueueEqual(
37+
res.queue, [0, 0, 0, 3, 2, 5], compute_score,
38+
states, states[1:] + [None])
39+
self.assertIsNone(next_state)
40+
res_scores = sorted([compute_score(x) for x in states])
41+
self.assertListEqual(res.scores, res_scores)
42+
self.assertIsNot(scores, res.scores)
43+
self.assertEqual(task.set_partial_result.call_count, 6)
44+
45+
def test_run_vizrank_interrupt(self):
46+
scores, task = [], Mock()
47+
# interrupt calculation in third iteration
48+
task.is_interruption_requested.side_effect = lambda: \
49+
True if task.is_interruption_requested.call_count > 2 else False
50+
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
51+
res = run_vizrank(compute_score, chain(states), scores, task)
52+
53+
next_state = self.assertQueueEqual(
54+
res.queue, [0, 0], compute_score, states[:2], states[1:3])
55+
self.assertEqual(next_state, (0, 3))
56+
res_scores = sorted([compute_score(x) for x in states[:2]])
57+
self.assertListEqual(res.scores, res_scores)
58+
self.assertIsNot(scores, res.scores)
59+
self.assertEqual(task.set_partial_result.call_count, 2)
60+
61+
# continue calculation through all states
62+
task.is_interruption_requested.side_effect = lambda: False
63+
i = states.index(next_state)
64+
res = run_vizrank(compute_score, chain(states[i:]), res_scores, task)
65+
66+
next_state = self.assertQueueEqual(
67+
res.queue, [0, 3, 2, 5], compute_score, states[2:],
68+
states[3:] + [None])
69+
self.assertIsNone(next_state)
70+
res_scores = sorted([compute_score(x) for x in states])
71+
self.assertListEqual(res.scores, res_scores)
72+
self.assertIsNot(scores, res.scores)
73+
self.assertEqual(task.set_partial_result.call_count, 6)
74+
75+
def assertQueueEqual(self, queue, positions, f, states, next_states):
76+
self.assertIsInstance(queue, Queue)
77+
for qs in (QueuedScore(position=p, score=f(s), state=s, next_state=ns)
78+
for p, s, ns in zip(positions, states, next_states)):
79+
result = queue.get_nowait()
80+
self.assertEqual(result.position, qs.position)
81+
self.assertEqual(result.state, qs.state)
82+
self.assertEqual(result.next_state, qs.next_state)
83+
self.assertEqual(result.score, qs.score)
84+
next_state = result.next_state
85+
return next_state
86+
87+
88+
class TestVizRankDialog(WidgetTest):
89+
def test_on_partial_result(self):
90+
def iterate_states(initial_state):
91+
if initial_state is not None:
92+
return chain(states[states.index(initial_state):])
93+
return chain(states)
94+
95+
def invoke_on_partial_result():
96+
widget.on_partial_result(run_vizrank(
97+
widget.compute_score,
98+
widget.iterate_states(widget.saved_state),
99+
widget.scores, task
100+
))
101+
102+
task = Mock()
103+
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
104+
105+
widget = VizRankDialog(None)
106+
widget.progressBarInit()
107+
widget.compute_score = compute_score
108+
widget.iterate_states = iterate_states
109+
widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))]
110+
111+
# interrupt calculation in third iteration
112+
task.is_interruption_requested.side_effect = lambda: \
113+
True if task.is_interruption_requested.call_count > 2 else False
114+
invoke_on_partial_result()
115+
self.assertEqual(widget.rank_model.rowCount(), 2)
116+
for row, score in enumerate(
117+
sorted([compute_score(x) for x in states[:2]])):
118+
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
119+
self.assertEqual(widget.saved_progress, 2)
120+
121+
# continue calculation through all states
122+
task.is_interruption_requested.side_effect = lambda: False
123+
invoke_on_partial_result()
124+
self.assertEqual(widget.rank_model.rowCount(), 6)
125+
for row, score in enumerate(
126+
sorted([compute_score(x) for x in states])):
127+
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
128+
self.assertEqual(widget.saved_progress, 6)
129+
130+
131+
if __name__ == "__main__":
132+
unittest.main()

0 commit comments

Comments
 (0)