Skip to content

Commit e7c89a2

Browse files
committed
owneuralnetwork: connect callbacks to NN through n_iters_
MLPClassifier n_iters_ was made a property which calls a callback.
1 parent 3fc7292 commit e7c89a2

File tree

5 files changed

+57
-27
lines changed

5 files changed

+57
-27
lines changed

Orange/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,11 @@ def __call__(self, data):
383383
m.params = self.params
384384
return m
385385

386+
def _initialize_wrapped(self):
387+
return self.__wraps__(**self.params)
388+
386389
def fit(self, X, Y, W=None):
387-
clf = self.__wraps__(**self.params)
390+
clf = self._initialize_wrapped()
388391
Y = Y.reshape(-1)
389392
if W is None or not self.supports_weights:
390393
return self.__returns__(clf.fit(X, Y))

Orange/classification/neural_network.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,28 @@
55
__all__ = ["NNClassificationLearner"]
66

77

8+
class NIterCallbackMixin:
9+
orange_callback = None
10+
11+
@property
12+
def n_iter_(self):
13+
return self.__orange_n_iter
14+
15+
@n_iter_.setter
16+
def n_iter_(self, v):
17+
self.__orange_n_iter = v
18+
if self.orange_callback:
19+
self.orange_callback(v)
20+
21+
22+
class MLPClassifierWCallback(skl_nn.MLPClassifier, NIterCallbackMixin):
23+
pass
24+
25+
826
class NNClassificationLearner(NNBase, SklLearner):
9-
__wraps__ = skl_nn.MLPClassifier
27+
__wraps__ = MLPClassifierWCallback
28+
29+
def _initialize_wrapped(self):
30+
clf = SklLearner._initialize_wrapped(self)
31+
clf.orange_callback = getattr(self, "callback", None)
32+
return clf

Orange/modelling/neural_network.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@
88
class NNLearner(SklFitter):
99
__fits__ = {'classification': NNClassificationLearner,
1010
'regression': NNRegressionLearner}
11+
12+
callback = None
13+
14+
def get_learner(self, problem_type):
15+
learner = super().get_learner(problem_type)
16+
learner.callback = self.callback
17+
return learner
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
import sklearn.neural_network as skl_nn
22
from Orange.base import NNBase
33
from Orange.regression import SklLearner
4+
from Orange.classification.neural_network import NIterCallbackMixin
45

56
__all__ = ["NNRegressionLearner"]
67

78

9+
class MLPRegressorWCallback(skl_nn.MLPRegressor, NIterCallbackMixin):
10+
pass
11+
12+
813
class NNRegressionLearner(NNBase, SklLearner):
9-
__wraps__ = skl_nn.MLPRegressor
14+
__wraps__ = MLPRegressorWCallback
15+
16+
def _initialize_wrapped(self):
17+
clf = SklLearner._initialize_wrapped(self)
18+
clf.orange_callback = getattr(self, "callback", None)
19+
return clf

Orange/widgets/model/owneuralnetwork.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from functools import partial
2+
import copy
23
import logging
34
import re
45
import sys
5-
from unittest.mock import patch
66
import concurrent.futures
77

88
from AnyQt.QtWidgets import QApplication, qApp
@@ -20,7 +20,6 @@
2020
)
2121

2222

23-
2423
class Task:
2524
"""
2625
A class that will hold the state for an learner evaluation.
@@ -152,32 +151,20 @@ def __update(self):
152151

153152
max_iter = self.learner.kwargs["max_iter"]
154153

155-
def callback(iteration=None):
154+
def callback(iteration):
156155
if task.cancelled:
157156
raise CancelThreadException() # this stop the thread
158-
if iteration is not None:
159-
set_progress(iteration/max_iter*100)
160-
161-
def print_callback(*args, **kwargs):
162-
iters = None
163-
# try to parse iteration number
164-
if args and args[0] and isinstance(args[0], str):
165-
find = re.findall(r"Iteration (\d+)", args[0])
166-
if find:
167-
iters = int(find[0])
168-
callback(iters)
157+
set_progress(iteration/max_iter*100)
158+
159+
# copy to set the callback so that the learner output is not modified
160+
# (currently we can not pass callbacks to learners __call__)
161+
learner = copy.copy(self.learner)
162+
learner.callback = callback
169163

170164
def build_model(data, learner):
171-
if learner.kwargs["solver"] != "lbfgs":
172-
# enable verbose printouts within scikit and redirect them
173-
with patch.dict(learner.kwargs, {"verbose": True}),\
174-
patch("builtins.print", print_callback):
175-
return learner(data)
176-
else:
177-
# lbfgs solver uses different mechanism
178-
return learner(data)
179-
180-
build_model_func = partial(build_model, self.data, self.learner)
165+
return learner(data)
166+
167+
build_model_func = partial(build_model, self.data, learner)
181168

182169
self.progressBarInit()
183170

0 commit comments

Comments
 (0)