Skip to content

Commit 209cfbe

Browse files
authored
Merge pull request #2958 from markotoplak/nn_thread
[ENH] Neural network widget that works in a separate thread
2 parents 9c4bd23 + 48bb3bc commit 209cfbe

File tree

6 files changed

+226
-5
lines changed

6 files changed

+226
-5
lines changed

Orange/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,11 @@ def __call__(self, data):
380380
m.params = self.params
381381
return m
382382

383+
def _initialize_wrapped(self):
384+
return self.__wraps__(**self.params)
385+
383386
def fit(self, X, Y, W=None):
384-
clf = self.__wraps__(**self.params)
387+
clf = self._initialize_wrapped()
385388
Y = Y.reshape(-1)
386389
if W is None or not self.supports_weights:
387390
return self.__returns__(clf.fit(X, Y))

Orange/classification/neural_network.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,28 @@
55
__all__ = ["NNClassificationLearner"]
66

77

8+
class NIterCallbackMixin:
9+
orange_callback = None
10+
11+
@property
12+
def n_iter_(self):
13+
return self.__orange_n_iter
14+
15+
@n_iter_.setter
16+
def n_iter_(self, v):
17+
self.__orange_n_iter = v
18+
if self.orange_callback:
19+
self.orange_callback(v)
20+
21+
22+
class MLPClassifierWCallback(skl_nn.MLPClassifier, NIterCallbackMixin):
23+
pass
24+
25+
826
class NNClassificationLearner(NNBase, SklLearner):
9-
__wraps__ = skl_nn.MLPClassifier
27+
__wraps__ = MLPClassifierWCallback
28+
29+
def _initialize_wrapped(self):
30+
clf = SklLearner._initialize_wrapped(self)
31+
clf.orange_callback = getattr(self, "callback", None)
32+
return clf

Orange/modelling/neural_network.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@
88
class NNLearner(SklFitter):
99
__fits__ = {'classification': NNClassificationLearner,
1010
'regression': NNRegressionLearner}
11+
12+
callback = None
13+
14+
def get_learner(self, problem_type):
15+
learner = super().get_learner(problem_type)
16+
learner.callback = self.callback
17+
return learner
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
import sklearn.neural_network as skl_nn
22
from Orange.base import NNBase
33
from Orange.regression import SklLearner
4+
from Orange.classification.neural_network import NIterCallbackMixin
45

56
__all__ = ["NNRegressionLearner"]
67

78

9+
class MLPRegressorWCallback(skl_nn.MLPRegressor, NIterCallbackMixin):
10+
pass
11+
12+
813
class NNRegressionLearner(NNBase, SklLearner):
9-
__wraps__ = skl_nn.MLPRegressor
14+
__wraps__ = MLPRegressorWCallback
15+
16+
def _initialize_wrapped(self):
17+
clf = SklLearner._initialize_wrapped(self)
18+
clf.orange_callback = getattr(self, "callback", None)
19+
return clf

Orange/widgets/model/owneuralnetwork.py

Lines changed: 164 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,62 @@
1+
from functools import partial
2+
import copy
3+
import logging
14
import re
25
import sys
6+
import concurrent.futures
37

48
from AnyQt.QtWidgets import QApplication
5-
from AnyQt.QtCore import Qt
9+
from AnyQt.QtCore import Qt, QThread, QObject
10+
from AnyQt.QtCore import pyqtSlot as Slot, pyqtSignal as Signal
611

712
from Orange.data import Table
813
from Orange.modelling import NNLearner
914
from Orange.widgets import gui
1015
from Orange.widgets.settings import Setting
1116
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
1217

18+
from Orange.widgets.utils.concurrent import ThreadExecutor, FutureWatcher
19+
20+
21+
class Task(QObject):
22+
"""
23+
A class that will hold the state for an learner evaluation.
24+
"""
25+
done = Signal(object)
26+
progressChanged = Signal(float)
27+
28+
future = None # type: concurrent.futures.Future
29+
watcher = None # type: FutureWatcher
30+
cancelled = False # type: bool
31+
32+
def setFuture(self, future):
33+
if self.future is not None:
34+
raise RuntimeError("future is already set")
35+
self.future = future
36+
self.watcher = FutureWatcher(future, parent=self)
37+
self.watcher.done.connect(self.done)
38+
39+
def cancel(self):
40+
"""
41+
Cancel the task.
42+
43+
Set the `cancelled` field to True and block until the future is done.
44+
"""
45+
# set cancelled state
46+
self.cancelled = True
47+
self.future.cancel()
48+
concurrent.futures.wait([self.future])
49+
50+
def emitProgressUpdate(self, value):
51+
self.progressChanged.emit(value)
52+
53+
def isInterruptionRequested(self):
54+
return self.cancelled
55+
56+
57+
class CancelTaskException(BaseException):
58+
pass
59+
1360

1461
class OWNNLearner(OWBaseLearner):
1562
name = "Neural Network"
@@ -53,11 +100,20 @@ def add_main_layout(self):
53100
label="Alpha:", decimals=5, alignment=Qt.AlignRight,
54101
callback=self.settings_changed, controlWidth=80)
55102
self.max_iter_spin = gui.spin(
56-
box, self, "max_iterations", 10, 300, step=10,
103+
box, self, "max_iterations", 10, 10000, step=10,
57104
label="Max iterations:", orientation=Qt.Horizontal,
58105
alignment=Qt.AlignRight, callback=self.settings_changed,
59106
controlWidth=80)
60107

108+
def setup_layout(self):
109+
super().setup_layout()
110+
111+
self._task = None # type: Optional[Task]
112+
self._executor = ThreadExecutor()
113+
114+
# just a test cancel button
115+
gui.button(self.controlArea, self, "Cancel", callback=self.cancel)
116+
61117
def create_learner(self):
62118
return self.LEARNER(
63119
hidden_layer_sizes=self.get_hidden_layers(),
@@ -81,6 +137,112 @@ def get_hidden_layers(self):
81137
self.hidden_layers_edit.setText("100,")
82138
return layers
83139

140+
def update_model(self):
141+
self.show_fitting_failed(None)
142+
self.model = None
143+
if self.check_data():
144+
self.__update()
145+
else:
146+
self.Outputs.model.send(self.model)
147+
148+
@Slot(float)
149+
def setProgressValue(self, value):
150+
assert self.thread() is QThread.currentThread()
151+
self.progressBarSet(value)
152+
153+
def __update(self):
154+
if self._task is not None:
155+
# First make sure any pending tasks are cancelled.
156+
self.cancel()
157+
assert self._task is None
158+
159+
max_iter = self.learner.kwargs["max_iter"]
160+
161+
# Setup the task state
162+
task = Task()
163+
lastemitted = 0.
164+
165+
def callback(iteration):
166+
nonlocal task # type: Task
167+
nonlocal lastemitted
168+
if task.isInterruptionRequested():
169+
raise CancelTaskException()
170+
progress = round(iteration / max_iter * 100)
171+
if progress != lastemitted:
172+
task.emitProgressUpdate(progress)
173+
lastemitted = progress
174+
175+
# copy to set the callback so that the learner output is not modified
176+
# (currently we can not pass callbacks to learners __call__)
177+
learner = copy.copy(self.learner)
178+
learner.callback = callback
179+
180+
def build_model(data, learner):
181+
try:
182+
return learner(data)
183+
except CancelTaskException:
184+
return None
185+
186+
build_model_func = partial(build_model, self.data, learner)
187+
188+
task.setFuture(self._executor.submit(build_model_func))
189+
task.done.connect(self._task_finished)
190+
task.progressChanged.connect(self.setProgressValue)
191+
192+
self._task = task
193+
self.progressBarInit()
194+
self.setBlocking(True)
195+
196+
@Slot(concurrent.futures.Future)
197+
def _task_finished(self, f):
198+
"""
199+
Parameters
200+
----------
201+
f : Future
202+
The future instance holding the built model
203+
"""
204+
assert self.thread() is QThread.currentThread()
205+
assert self._task is not None
206+
assert self._task.future is f
207+
assert f.done()
208+
self._task.deleteLater()
209+
self._task = None
210+
self.setBlocking(False)
211+
self.progressBarFinished()
212+
213+
try:
214+
self.model = f.result()
215+
except Exception as ex: # pylint: disable=broad-except
216+
# Log the exception with a traceback
217+
log = logging.getLogger()
218+
log.exception(__name__, exc_info=True)
219+
self.model = None
220+
self.show_fitting_failed(ex)
221+
else:
222+
self.model.name = self.learner_name
223+
self.model.instances = self.data
224+
self.Outputs.model.send(self.model)
225+
226+
def cancel(self):
227+
"""
228+
Cancel the current task (if any).
229+
"""
230+
if self._task is not None:
231+
self._task.cancel()
232+
assert self._task.future.done()
233+
# disconnect from the task
234+
self._task.done.disconnect(self._task_finished)
235+
self._task.progressChanged.disconnect(self.setProgressValue)
236+
self._task.deleteLater()
237+
self._task = None
238+
239+
self.progressBarFinished()
240+
self.setBlocking(False)
241+
242+
def onDeleteWidget(self):
243+
self.cancel()
244+
super().onDeleteWidget()
245+
84246

85247
if __name__ == "__main__":
86248
a = QApplication(sys.argv)

Orange/widgets/tests/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,13 +514,16 @@ def test_input_data(self):
514514
self.assertEqual(self.widget.data, None)
515515
self.send_signal("Data", self.data)
516516
self.assertEqual(self.widget.data, self.data)
517+
self.wait_until_stop_blocking()
517518

518519
def test_input_data_disconnect(self):
519520
"""Check widget's data and model after disconnecting data from input"""
520521
self.send_signal("Data", self.data)
521522
self.assertEqual(self.widget.data, self.data)
522523
self.widget.apply_button.button.click()
524+
self.wait_until_stop_blocking()
523525
self.send_signal("Data", None)
526+
self.wait_until_stop_blocking()
524527
self.assertEqual(self.widget.data, None)
525528
self.assertIsNone(self.get_output(self.widget.Outputs.model))
526529

@@ -529,9 +532,11 @@ def test_input_data_learner_adequacy(self):
529532
for inadequate in self.inadequate_dataset:
530533
self.send_signal("Data", inadequate)
531534
self.widget.apply_button.button.click()
535+
self.wait_until_stop_blocking()
532536
self.assertTrue(self.widget.Error.data_error.is_shown())
533537
for valid in self.valid_datasets:
534538
self.send_signal("Data", valid)
539+
self.wait_until_stop_blocking()
535540
self.assertFalse(self.widget.Error.data_error.is_shown())
536541

537542
def test_input_preprocessor(self):
@@ -542,6 +547,7 @@ def test_input_preprocessor(self):
542547
randomize, self.widget.preprocessors,
543548
'Preprocessor not added to widget preprocessors')
544549
self.widget.apply_button.button.click()
550+
self.wait_until_stop_blocking()
545551
self.assertEqual(
546552
(randomize,), self.widget.learner.preprocessors,
547553
'Preprocessors were not passed to the learner')
@@ -551,6 +557,7 @@ def test_input_preprocessors(self):
551557
pp_list = PreprocessorList([Randomize(), RemoveNaNColumns()])
552558
self.send_signal("Preprocessor", pp_list)
553559
self.widget.apply_button.button.click()
560+
self.wait_until_stop_blocking()
554561
self.assertEqual(
555562
(pp_list,), self.widget.learner.preprocessors,
556563
'`PreprocessorList` was not added to preprocessors')
@@ -560,10 +567,12 @@ def test_input_preprocessor_disconnect(self):
560567
randomize = Randomize()
561568
self.send_signal("Preprocessor", randomize)
562569
self.widget.apply_button.button.click()
570+
self.wait_until_stop_blocking()
563571
self.assertEqual(randomize, self.widget.preprocessors)
564572

565573
self.send_signal("Preprocessor", None)
566574
self.widget.apply_button.button.click()
575+
self.wait_until_stop_blocking()
567576
self.assertIsNone(self.widget.preprocessors,
568577
'Preprocessors not removed on disconnect.')
569578

@@ -585,6 +594,7 @@ def test_output_model(self):
585594
self.assertIsNone(self.get_output(self.widget.Outputs.model))
586595
self.send_signal('Data', self.data)
587596
self.widget.apply_button.button.click()
597+
self.wait_until_stop_blocking()
588598
model = self.get_output(self.widget.Outputs.model)
589599
self.assertIsNotNone(model)
590600
self.assertIsInstance(model, self.widget.LEARNER.__returns__)
@@ -598,6 +608,7 @@ def test_output_learner_name(self):
598608
self.widget.name_line_edit.text())
599609
self.widget.name_line_edit.setText(new_name)
600610
self.widget.apply_button.button.click()
611+
self.wait_until_stop_blocking()
601612
self.assertEqual(self.get_output("Learner").name, new_name)
602613

603614
def test_output_model_name(self):
@@ -606,6 +617,7 @@ def test_output_model_name(self):
606617
self.widget.name_line_edit.setText(new_name)
607618
self.send_signal("Data", self.data)
608619
self.widget.apply_button.button.click()
620+
self.wait_until_stop_blocking()
609621
self.assertEqual(self.get_output(self.widget.Outputs.model).name, new_name)
610622

611623
def _get_param_value(self, learner, param):
@@ -626,6 +638,7 @@ def test_parameters_default(self):
626638
for dataset in self.valid_datasets:
627639
self.send_signal("Data", dataset)
628640
self.widget.apply_button.button.click()
641+
self.wait_until_stop_blocking()
629642
for parameter in self.parameters:
630643
# Skip if the param isn't used for the given data type
631644
if self._should_check_parameter(parameter, dataset):
@@ -639,6 +652,7 @@ def test_parameters(self):
639652
# to only certain problem types
640653
for dataset in self.valid_datasets:
641654
self.send_signal("Data", dataset)
655+
self.wait_until_stop_blocking()
642656

643657
for parameter in self.parameters:
644658
# Skip if the param isn't used for the given data type
@@ -650,6 +664,7 @@ def test_parameters(self):
650664
for value in parameter.values:
651665
parameter.set_value(value)
652666
self.widget.apply_button.button.click()
667+
self.wait_until_stop_blocking()
653668
param = self._get_param_value(self.widget.learner, parameter)
654669
self.assertEqual(
655670
param, parameter.get_value(),
@@ -674,6 +689,7 @@ def test_params_trigger_settings_changed(self):
674689
"""Check that the learner gets updated whenever a param is changed."""
675690
for dataset in self.valid_datasets:
676691
self.send_signal("Data", dataset)
692+
self.wait_until_stop_blocking()
677693

678694
for parameter in self.parameters:
679695
# Skip if the param isn't used for the given data type

0 commit comments

Comments
 (0)