Skip to content

Commit 3217a99

Browse files
committed
Refactor start()
1 parent 2768723 commit 3217a99

File tree

3 files changed

+45
-57
lines changed

3 files changed

+45
-57
lines changed

Orange/widgets/utils/concurrent.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ def is_interruption_requested(self) -> bool:
832832
return self.__interruption_requested
833833

834834
def start(self, executor: concurrent.futures.Executor,
835-
func: Callable[[], Any]=None) -> Future:
835+
func: Callable[[], Any] = None) -> Future:
836836
assert self.future is None
837837
assert not self.__interruption_requested
838838
self.__future = executor.submit(func)
@@ -863,15 +863,11 @@ def __init__(self):
863863
def task(self) -> TaskState:
864864
return self.__task
865865

866-
def _prepare_task(self, state: TaskState) -> Callable[[], Any]:
867-
raise NotImplementedError
868-
869866
def _on_partial_result(self, result: Any) -> None:
870867
raise NotImplementedError
871868

872869
def _on_done(self, result: Any) -> None:
873-
# NOTE: All of these have already been set by _on_partial_result,
874-
# we double check that they are aliases
870+
""" Invoked when task is done. """
875871
raise NotImplementedError
876872

877873
def __set_state_ready(self):
@@ -893,7 +889,7 @@ def __start_task(self, task: Callable[[], Any], state: TaskState):
893889
state.setParent(self)
894890
self.__task = state
895891

896-
def __cancel_task(self, wait: bool=True):
892+
def __cancel_task(self, wait: bool = True):
897893
if self.__task is not None:
898894
state, self.__task = self.__task, None
899895
state.cancel()
@@ -909,7 +905,6 @@ def __cancel_task(self, wait: bool=True):
909905
w.done.connect(state.deleteLater)
910906

911907
def __on_task_done(self, future: Future):
912-
""" Invoked when task is done. """
913908
assert future.done()
914909
assert self.__task is not None
915910
assert self.__task.future is future
@@ -920,16 +915,18 @@ def __on_task_done(self, future: Future):
920915
result = future.result()
921916
self._on_done(result)
922917

923-
def start(self):
918+
def start(self, task: Callable, *args):
924919
""" Call to start the task. """
925920
self.__cancel_task(wait=False)
926921

927922
if self.data is None:
928923
self.__set_state_ready()
929924
return
930925

926+
assert callable(task), "`task` must be callable!"
931927
state = TaskState(self)
932-
task = self._prepare_task(state)
928+
task = partial(task, *(args + (state,)))
929+
933930
self.__set_state_busy()
934931
self.__start_task(task, state)
935932

Orange/widgets/utils/tests/concurrent_example.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# pylint: disable=too-many-ancestors
22
from typing import Optional
33
from types import SimpleNamespace as namespace
4-
from functools import partial
54

65
import numpy as np
76

@@ -28,8 +27,6 @@ def run(data: Table, embedding: Optional[np.ndarray], state: TaskState):
2827
x_data = np.array(np.mean(data.X, axis=1))
2928
if x_data.ndim == 2:
3029
x_data = x_data.ravel()
31-
y_data = np.ones(len(x_data))
32-
y_data[::2] = step % 2
3330
y_data = np.random.rand(len(x_data))
3431
# Needs a copy because projection should not be modified
3532
# inplace. If it is modified inplace, the widget and the thread
@@ -70,7 +67,7 @@ def _add_controls(self):
7067
super()._add_controls()
7168

7269
def __param_combo_changed(self):
73-
super().start()
70+
self._run()
7471

7572
def _toggle_run(self):
7673
if self.data is None:
@@ -83,12 +80,13 @@ def _toggle_run(self):
8380
self.commit()
8481
# Resume task
8582
else:
86-
self.start()
83+
self._run()
8784

88-
# extend ConcurrentWidgetMixin
89-
def _prepare_task(self, state: TaskState):
90-
return partial(run, self.data, embedding=self.embedding, state=state)
85+
def _run(self):
86+
self.run_button.setText("Stop")
87+
self.start(run, self.data, self.embedding)
9188

89+
# ConcurrentWidgetMixin
9290
def _on_partial_result(self, result: Result):
9391
assert isinstance(result.embedding, np.ndarray)
9492
assert len(result.embedding) == len(self.data)
@@ -101,21 +99,19 @@ def _on_partial_result(self, result: Result):
10199
self.graph.update_density()
102100

103101
def _on_done(self, result: Result):
102+
# NOTE: All of these have already been set by _on_partial_result,
103+
# we double check that they are aliases
104104
assert isinstance(result.embedding, np.ndarray)
105105
assert len(result.embedding) == len(self.data)
106106
self.embedding = result.embedding
107107
self.run_button.setText("Start")
108108
self.commit()
109109

110-
def start(self):
111-
self.run_button.setText("Stop")
112-
super().start()
113-
114-
# extend OWDataProjectionWidget
110+
# OWDataProjectionWidget
115111
def set_data(self, data: Table):
116112
super().set_data(data)
117113
if self._invalidated:
118-
self.start()
114+
self._run()
119115

120116
def get_embedding(self):
121117
if self.embedding is None:

Orange/widgets/visualize/owfreeviz.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# pylint: disable=too-many-ancestors
22
from enum import IntEnum
33
from types import SimpleNamespace as namespace
4-
from functools import partial
54

65
import numpy as np
76

@@ -187,7 +186,7 @@ def __init_combo_changed(self):
187186
self.setup_plot()
188187
self.commit()
189188
if self.task is not None:
190-
super().start()
189+
self._run()
191190

192191
def _toggle_run(self):
193192
if self.data is None:
@@ -199,15 +198,15 @@ def _toggle_run(self):
199198
self.run_button.setText("Resume")
200199
self.commit()
201200
else:
202-
self.start()
201+
self._run()
203202

204-
# extend ConcurrentWidgetMixin
205-
def _prepare_task(self, state: TaskState):
206-
return partial(run_freeviz, self.effective_data,
207-
projector=self.projector,
208-
projection=self.projection,
209-
state=state)
203+
def _run(self):
204+
self.graph.set_sample_size(self.SAMPLE_SIZE)
205+
self.run_button.setText("Stop")
206+
self.start(run_freeviz, self.effective_data,
207+
self.projector, self.projection)
210208

209+
# ConcurrentWidgetMixin
211210
def _on_partial_result(self, result: Result):
212211
assert isinstance(result.projector, FreeViz)
213212
assert isinstance(result.projection, FreeVizModel)
@@ -225,10 +224,26 @@ def _on_done(self, results: Result):
225224
self.run_button.setText("Start")
226225
self.commit()
227226

228-
def start(self):
229-
self.graph.set_sample_size(self.SAMPLE_SIZE)
230-
self.run_button.setText("Stop")
231-
super().start()
227+
# OWAnchorProjectionWidget
228+
def set_data(self, data):
229+
super().set_data(data)
230+
if self._invalidated:
231+
self.init_projection()
232+
233+
def init_projection(self):
234+
if self.data is None:
235+
return
236+
anchors = FreeViz.init_radial(len(self.effective_variables)) \
237+
if self.initialization == InitType.Circular \
238+
else FreeViz.init_random(len(self.effective_variables), 2)
239+
self.projector = FreeViz(scale=False, center=False,
240+
initial=anchors, maxiter=10)
241+
data = self.projector.preprocess(self.effective_data)
242+
self.projector.domain = data.domain
243+
self.projector.components_ = anchors.T
244+
self.projection = FreeVizModel(self.projector, self.projector.domain, 2)
245+
self.projection.pre_domain = data.domain
246+
self.projection.name = self.projector.name
232247

233248
def check_data(self):
234249
def error(err):
@@ -255,26 +270,6 @@ def error(err):
255270
self.Warning.removed_features()
256271
self.run_button.setEnabled(self.data is not None)
257272

258-
def set_data(self, data):
259-
super().set_data(data)
260-
if self._invalidated:
261-
self.init_projection()
262-
263-
def init_projection(self):
264-
if self.data is None:
265-
return
266-
anchors = FreeViz.init_radial(len(self.effective_variables)) \
267-
if self.initialization == InitType.Circular \
268-
else FreeViz.init_random(len(self.effective_variables), 2)
269-
self.projector = FreeViz(scale=False, center=False,
270-
initial=anchors, maxiter=10)
271-
data = self.projector.preprocess(self.effective_data)
272-
self.projector.domain = data.domain
273-
self.projector.components_ = anchors.T
274-
self.projection = FreeVizModel(self.projector, self.projector.domain, 2)
275-
self.projection.pre_domain = data.domain
276-
self.projection.name = self.projector.name
277-
278273
def get_coordinates_data(self):
279274
embedding = self.get_embedding()
280275
if embedding is None:

0 commit comments

Comments
 (0)