Skip to content

Commit ee8a726

Browse files
committed
Refactor
1 parent 12d4415 commit ee8a726

File tree

5 files changed

+184
-189
lines changed

5 files changed

+184
-189
lines changed

Orange/widgets/tests/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,12 +1018,16 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
10181018

10191019
def test_saved_selection(self, timeout=DEFAULT_TIMEOUT):
10201020
self.send_signal(self.widget.Inputs.data, self.data)
1021-
self.wait_until_stop_blocking()
1021+
if self.widget.isBlocking():
1022+
spy = QSignalSpy(self.widget.blockingStateChanged)
1023+
self.assertTrue(spy.wait(timeout))
10221024
self.widget.graph.select_by_indices(list(range(0, len(self.data), 10)))
10231025
settings = self.widget.settingsHandler.pack_data(self.widget)
10241026
w = self.create_widget(self.widget.__class__, stored_settings=settings)
10251027
self.send_signal(self.widget.Inputs.data, self.data, widget=w)
1026-
self.wait_until_stop_blocking(widget=w)
1028+
if w.isBlocking():
1029+
spy = QSignalSpy(w.blockingStateChanged)
1030+
self.assertTrue(spy.wait(timeout))
10271031
self.assertEqual(np.sum(w.graph.selection), 15)
10281032
np.testing.assert_equal(self.widget.graph.selection, w.graph.selection)
10291033

Orange/widgets/utils/concurrent.py

Lines changed: 58 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -813,34 +813,33 @@ def __init__(self, *args):
813813
self.partial_result_ready, Qt.QueuedConnection)
814814

815815
@property
816-
def future(self):
817-
# type: () -> Future
816+
def future(self) -> Future:
818817
return self.__future
819818

820-
def set_status(self, text):
819+
def set_status(self, text: str):
821820
self._p_status_changed.emit(text)
822821

823-
def set_progress_value(self, value):
822+
def set_progress_value(self, value: float):
824823
if round(value, 1) > round(self.__progress, 1):
825824
# Only emit progress when it has changed sufficiently
826825
self._p_progress_changed.emit(value)
827826
self.__progress = value
828827

829-
def set_partial_results(self, value):
828+
def set_partial_result(self, value: Any):
830829
self._p_partial_result_ready.emit(value)
831830

832-
def is_interruption_requested(self):
831+
def is_interruption_requested(self) -> bool:
833832
return self.__interruption_requested
834833

835-
def start(self, executor, func=None):
836-
# type: (concurrent.futures.Executor, Callable[[], Any]) -> Future
834+
def start(self, executor: concurrent.futures.Executor,
835+
func: Callable[[], Any] = None) -> Future:
837836
assert self.future is None
838837
assert not self.__interruption_requested
839838
self.__future = executor.submit(func)
840839
self.watcher.setFuture(self.future)
841840
return self.future
842841

843-
def cancel(self):
842+
def cancel(self) -> bool:
844843
assert not self.__interruption_requested
845844
self.__interruption_requested = True
846845
if self.future is not None:
@@ -861,22 +860,55 @@ def __init__(self):
861860
self.__task = None # type: Optional[TaskState]
862861

863862
@property
864-
def task(self):
863+
def task(self) -> TaskState:
865864
return self.__task
866865

867-
def _prepare_task(self, state: TaskState) -> Callable[[], Any]:
866+
def on_partial_result(self, result: Any) -> None:
867+
""" Invoked from runner (by state) to send the partial results
868+
The method should handle partial results, i.e. show them in the plot.
869+
870+
:param result: any data structure to hold temporary result
871+
"""
868872
raise NotImplementedError
869873

870-
def _set_partial_results(self, result: Any) -> None:
874+
def on_done(self, result: Any) -> None:
875+
""" Invoked when task is done.
876+
The method should re-set the result (to double check it) and
877+
perform operations with obtained results, eg. send data to the output.
878+
879+
:param result: any data structure to hold temporary result
880+
"""
871881
raise NotImplementedError
872882

873-
def _set_results(self, results: Any) -> None:
874-
# NOTE: All of these have already been set by _set_partial_results,
875-
# we double check that they are aliases
876-
for key in results.__dict__:
877-
value = getattr(results, key)
878-
if value is not None:
879-
setattr(self, key, value)
883+
def start(self, task: Callable, *args, **kwargs):
884+
""" Call from derived class to start the task.
885+
:param task: runner - a method to run in a thread - should accept
886+
`state` parameter
887+
"""
888+
self.__cancel_task(wait=False)
889+
890+
if self.data is None:
891+
self.__set_state_ready()
892+
return
893+
894+
assert callable(task), "`task` must be callable!"
895+
state = TaskState(self)
896+
task = partial(task, *(args + (state,)), **kwargs)
897+
898+
self.__set_state_busy()
899+
self.__start_task(task, state)
900+
901+
def cancel(self):
902+
""" Call from derived class to stop the task. """
903+
self.__cancel_task(wait=False)
904+
self.__set_state_ready()
905+
906+
def shutdown(self):
907+
""" Call from derived class when the widget is deleted
908+
(in onDeleteWidget).
909+
"""
910+
self.__cancel_task(wait=True)
911+
self.__executor.shutdown(True)
880912

881913
def __set_state_ready(self):
882914
self.progressBarFinished()
@@ -891,42 +923,28 @@ def __start_task(self, task: Callable[[], Any], state: TaskState):
891923
assert self.__task is None
892924
state.status_changed.connect(self.setStatusMessage)
893925
state.progress_changed.connect(self.progressBarSet)
894-
state.partial_result_ready.connect(self._set_partial_results)
895-
state.watcher.done.connect(self.on_done)
926+
state.partial_result_ready.connect(self.on_partial_result)
927+
state.watcher.done.connect(self.__on_task_done)
896928
state.start(self.__executor, task)
897929
state.setParent(self)
898930
self.__task = state
899931

900-
def __cancel_task(self, wait=True):
932+
def __cancel_task(self, wait: bool = True):
901933
if self.__task is not None:
902934
state, self.__task = self.__task, None
903935
state.cancel()
904-
state.partial_result_ready.disconnect(self._set_partial_results)
936+
state.partial_result_ready.disconnect(self.on_partial_result)
905937
state.status_changed.disconnect(self.setStatusMessage)
906938
state.progress_changed.disconnect(self.progressBarSet)
907-
state.watcher.done.disconnect(self.on_done)
939+
state.watcher.done.disconnect(self.__on_task_done)
908940
if wait:
909941
concurrent.futures.wait([state.future])
910942
state.deleteLater()
911943
else:
912944
w = FutureWatcher(state.future, parent=state)
913945
w.done.connect(state.deleteLater)
914946

915-
def start(self):
916-
""" Call to start the task. """
917-
self.__cancel_task(wait=False)
918-
919-
if self.data is None:
920-
self.__set_state_ready()
921-
return
922-
923-
state = TaskState(self)
924-
task = self._prepare_task(state)
925-
self.__set_state_busy()
926-
self.__start_task(task, state)
927-
928-
def on_done(self, future: Future):
929-
""" Invoked when task is done. """
947+
def __on_task_done(self, future: Future):
930948
assert future.done()
931949
assert self.__task is not None
932950
assert self.__task.future is future
@@ -935,14 +953,4 @@ def on_done(self, future: Future):
935953
task.deleteLater()
936954
self.__set_state_ready()
937955
result = future.result()
938-
self._set_results(result)
939-
940-
def cancel(self):
941-
""" Call to stop the task. """
942-
self.__cancel_task(wait=False)
943-
self.__set_state_ready()
944-
945-
def shutdown(self):
946-
""" Call when widget is deleted (in onDeleteWidget). """
947-
self.__cancel_task(wait=True)
948-
self.__executor.shutdown(True)
956+
self.on_done(result)

Orange/widgets/utils/tests/concurrent_example.py

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# pylint: disable=too-many-ancestors
22
from typing import Optional
33
from types import SimpleNamespace as namespace
4-
from functools import partial
54

65
import numpy as np
76

@@ -13,42 +12,38 @@
1312
from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget
1413

1514

16-
class Results(namespace):
15+
class Result(namespace):
1716
embedding = None # type: Optional[np.ndarray]
1817

1918

20-
class Runner:
21-
@staticmethod
22-
def run(data: Table, embedding: Optional[np.ndarray], state: TaskState):
23-
res = Results(embedding=embedding)
24-
25-
# simulate wasteful calculation (increase 'steps')
26-
step, steps = 0, 10
27-
state.set_status("Calculating...")
28-
while step < steps:
29-
for _ in range(steps):
30-
x_data = np.array(np.mean(data.X, axis=1))
31-
if x_data.ndim == 2:
32-
x_data = x_data.ravel()
33-
y_data = np.ones(len(x_data))
34-
y_data[::2] = step % 2
35-
y_data = np.random.rand(len(x_data))
36-
# Needs a copy because projection should not be modified
37-
# inplace. If it is modified inplace, the widget and the thread
38-
# hold a reference to the same object. When the thread is
39-
# interrupted it is still modifying the object, but the widget
40-
# receives it (the modified object) with a delay.
41-
embedding = np.vstack((x_data, y_data)).T.copy()
42-
step += 1
43-
if step % (steps / 10) == 0:
44-
state.set_progress_value(100 * step / steps)
45-
46-
if state.is_interruption_requested():
47-
return res
48-
49-
res.embedding = embedding
50-
state.set_partial_results(res)
51-
return res
19+
def run(data: Table, embedding: Optional[np.ndarray], state: TaskState):
20+
res = Result(embedding=embedding)
21+
22+
# simulate wasteful calculation (increase 'steps')
23+
step, steps = 0, 10
24+
state.set_status("Calculating...")
25+
while step < steps:
26+
for _ in range(steps):
27+
x_data = np.array(np.mean(data.X, axis=1))
28+
if x_data.ndim == 2:
29+
x_data = x_data.ravel()
30+
y_data = np.random.rand(len(x_data))
31+
# Needs a copy because projection should not be modified
32+
# inplace. If it is modified inplace, the widget and the thread
33+
# hold a reference to the same object. When the thread is
34+
# interrupted it is still modifying the object, but the widget
35+
# receives it (the modified object) with a delay.
36+
embedding = np.vstack((x_data, y_data)).T.copy()
37+
step += 1
38+
if step % (steps / 10) == 0:
39+
state.set_progress_value(100 * step / steps)
40+
41+
if state.is_interruption_requested():
42+
return res
43+
44+
res.embedding = embedding
45+
state.set_partial_result(res)
46+
return res
5247

5348

5449
class OWConcurrentWidget(OWDataProjectionWidget, ConcurrentWidgetMixin):
@@ -72,7 +67,7 @@ def _add_controls(self):
7267
super()._add_controls()
7368

7469
def __param_combo_changed(self):
75-
super().start()
70+
self._run()
7671

7772
def _toggle_run(self):
7873
if self.data is None:
@@ -85,14 +80,14 @@ def _toggle_run(self):
8580
self.commit()
8681
# Resume task
8782
else:
88-
self.start()
83+
self._run()
8984

90-
# extend ConcurrentWidgetMixin
91-
def _prepare_task(self, state: TaskState):
92-
return partial(Runner.run, self.data,
93-
embedding=self.embedding, state=state)
85+
def _run(self):
86+
self.run_button.setText("Stop")
87+
self.start(run, self.data, self.embedding)
9488

95-
def _set_partial_results(self, result: Results):
89+
# ConcurrentWidgetMixin
90+
def on_partial_result(self, result: Result):
9691
assert isinstance(result.embedding, np.ndarray)
9792
assert len(result.embedding) == len(self.data)
9893
first_result = self.embedding is None
@@ -103,25 +98,20 @@ def _set_partial_results(self, result: Results):
10398
self.graph.update_coordinates()
10499
self.graph.update_density()
105100

106-
def _set_results(self, result: Results):
101+
def on_done(self, result: Result):
102+
# NOTE: All of these have already been set by on_partial_result,
103+
# we double check that they are aliased
107104
assert isinstance(result.embedding, np.ndarray)
108105
assert len(result.embedding) == len(self.data)
109-
super()._set_results(result)
110-
111-
def start(self):
112-
self.run_button.setText("Stop")
113-
super().start()
114-
115-
def on_done(self, future):
116-
super().on_done(future)
106+
self.embedding = result.embedding
117107
self.run_button.setText("Start")
118108
self.commit()
119109

120-
# extend OWDataProjectionWidget
110+
# OWDataProjectionWidget
121111
def set_data(self, data: Table):
122112
super().set_data(data)
123113
if self._invalidated:
124-
self.start()
114+
self._run()
125115

126116
def get_embedding(self):
127117
if self.embedding is None:

0 commit comments

Comments
 (0)