Skip to content

Commit 41f6906

Browse files
authored
Merge pull request #5710 from janezd/baselearnerwidget-pp-warning
[ENH] Learner widgets: Inform about potential problems when overriding preprocessors
2 parents 5d83588 + e44754d commit 41f6906

File tree

7 files changed

+133
-26
lines changed

7 files changed

+133
-26
lines changed

Orange/widgets/model/owadaboost.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ def set_base_learner(self, learner):
102102
self.base_estimator = learner or self.DEFAULT_BASE_ESTIMATOR
103103
self.base_label.setText(
104104
"Base estimator: %s" % self.base_estimator.name.title())
105-
if self.auto_apply:
106-
self.apply()
105+
self.learner = self.model = None
107106

108107
def get_learner_parameters(self):
109108
return (("Base estimator", self.base_estimator),

Orange/widgets/model/owcalibratedlearner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def add_main_layout(self):
6262
def set_learner(self, learner):
6363
self.base_learner = learner
6464
self._set_default_name()
65-
self.unconditional_apply()
65+
self.learner = self.model = None
6666

6767
def _set_default_name(self):
6868

Orange/widgets/model/owcurvefit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def __insert_into_expression(self, what: str, offset=0):
395395

396396
def set_data(self, data: Optional[Table]):
397397
self.Warning.data_missing(shown=not bool(data))
398+
self.learner = None
398399
super().set_data(data)
399400
self.__clear()
400401

@@ -419,7 +420,7 @@ def handleNewSignals(self):
419420
self.__init_models()
420421
self.__enable_controls()
421422
self.__set_pending()
422-
self.unconditional_apply()
423+
super().handleNewSignals()
423424

424425
def __preprocess_data(self):
425426
self.__pp_data = preprocess(self.data, self.preprocessors)

Orange/widgets/model/owlinearregression.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ def add_main_layout(self):
9494
self.controls.alpha_index.setEnabled(self.reg_type != self.OLS)
9595
self.l2_ratio_slider.setEnabled(self.reg_type == self.Elastic)
9696

97-
def handleNewSignals(self):
98-
self.apply()
99-
10097
def _intercept_changed(self):
10198
self.apply()
10299

Orange/widgets/model/owstack.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,26 @@ def add_main_layout(self):
3333
@Inputs.learners
3434
def set_learner(self, index: int, learner: Learner):
3535
self.learners[index] = learner
36+
self._invalidate()
3637

3738
@Inputs.learners.insert
3839
def insert_learner(self, index, learner):
3940
self.learners.insert(index, learner)
41+
self._invalidate()
4042

4143
@Inputs.learners.remove
4244
def remove_learner(self, index):
4345
self.learners.pop(index)
46+
self._invalidate()
4447

4548
@Inputs.aggregate
4649
def set_aggregate(self, aggregate):
4750
self.aggregate = aggregate
51+
self._invalidate()
4852

49-
def handleNewSignals(self):
50-
super().handleNewSignals()
51-
self.apply()
53+
def _invalidate(self):
54+
self.learner = self.model = None
55+
# ... and handleNewSignals will do the rest
5256

5357
def create_learner(self):
5458
if not self.learners:

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ class Error(OWWidget.Error):
7979
class Warning(OWWidget.Warning):
8080
outdated_learner = Msg("Press Apply to submit changes.")
8181

82+
class Information(OWWidget.Information):
83+
ignored_preprocessors = Msg(
84+
"Ignoring default preprocessing.\n"
85+
"Default preprocessing, such as scaling, one-hot encoding and "
86+
"treatment of missing data, has been replaced with user-specified "
87+
"preprocessors. Problems may occur if these are inadequate "
88+
"for the given data.")
89+
8290
class Inputs:
8391
data = Input("Data", Table)
8492
preprocessor = Input("Preprocessor", Preprocess)
@@ -90,6 +98,8 @@ class Outputs:
9098

9199
OUTPUT_MODEL_NAME = Outputs.model.name # Attr for backcompat w/ self.send() code
92100

101+
_SEND, _SOFT, _UPDATE = range(3)
102+
93103
def __init__(self, preprocessors=None):
94104
super().__init__()
95105
self.__default_learner_name = ""
@@ -99,6 +109,7 @@ def __init__(self, preprocessors=None):
99109
self.model = None
100110
self.preprocessors = preprocessors
101111
self.outdated_settings = False
112+
self.__apply_level = []
102113

103114
self.setup_layout()
104115
QTimer.singleShot(0, getattr(self, "unconditional_apply", self.apply))
@@ -144,7 +155,8 @@ def set_default_learner_name(self, name: str) -> None:
144155
@Inputs.preprocessor
145156
def set_preprocessor(self, preprocessor):
146157
self.preprocessors = preprocessor
147-
self.apply()
158+
# invalidate learner and model, so handleNewSignals will renew them
159+
self.learner = self.model = None
148160

149161
@Inputs.data
150162
@check_sql_input
@@ -164,23 +176,50 @@ def set_data(self, data):
164176
"Select one with the Select Columns widget.")
165177
self.data = None
166178

167-
self.update_model()
179+
# invalidate the model so that handleNewSignals will update it
180+
self.model = None
181+
168182

169183
def apply(self):
184+
level, self.__apply_level = max(self.__apply_level, default=self._UPDATE), []
170185
"""Applies learner and sends new model."""
171-
self.update_learner()
172-
self.update_model()
186+
if level == self._SEND:
187+
self._send_learner()
188+
self._send_model()
189+
elif level == self._UPDATE:
190+
self.update_learner()
191+
self.update_model()
192+
else:
193+
self.learner or self.update_learner()
194+
self.model or self.update_model()
195+
196+
def apply_as(self, level, unconditional=False):
197+
self.__apply_level.append(level)
198+
if unconditional:
199+
self.unconditional_apply()
200+
else:
201+
self.apply()
173202

174203
def update_learner(self):
175204
self.learner = self.create_learner()
176205
if self.learner and issubclass(self.LEARNER, Fitter):
177206
self.learner.use_default_preprocessors = True
178207
if self.learner is not None:
179208
self.learner.name = self.effective_learner_name()
209+
self._send_learner()
210+
211+
def _send_learner(self):
180212
self.Outputs.learner.send(self.learner)
181213
self.outdated_settings = False
182214
self.Warning.outdated_learner.clear()
183215

216+
def handleNewSignals(self):
217+
self.apply_as(self._SOFT, True)
218+
self.Information.ignored_preprocessors(
219+
shown=not getattr(self.learner, "use_default_preprocessors", False)
220+
and getattr(self.LEARNER, "preprocessors", False)
221+
and self.preprocessors is not None)
222+
184223
def show_fitting_failed(self, exc):
185224
"""Show error when fitting fails.
186225
Derived widgets can override this to show more specific messages."""
@@ -197,6 +236,9 @@ def update_model(self):
197236
else:
198237
self.model.name = self.learner_name or self.captionTitle
199238
self.model.instances = self.data
239+
self._send_model()
240+
241+
def _send_model(self):
200242
self.Outputs.model.send(self.model)
201243

202244
def check_data(self):
@@ -223,15 +265,12 @@ def settings_changed(self, *args, **kwargs):
223265
self.Warning.outdated_learner(shown=not self.auto_apply)
224266
self.apply()
225267

226-
def _change_name(self, instance, output):
227-
if instance:
228-
instance.name = self.effective_learner_name()
229-
if self.auto_apply:
230-
output.send(instance)
231-
232268
def learner_name_changed(self):
233-
self._change_name(self.learner, self.Outputs.learner)
234-
self._change_name(self.model, self.Outputs.model)
269+
if self.model is not None:
270+
self.model.name = self.effective_learner_name()
271+
if self.learner is not None:
272+
self.learner.name = self.effective_learner_name()
273+
self.apply_as(self._SEND)
235274

236275
def effective_learner_name(self):
237276
"""Return the effective learner name."""
@@ -272,7 +311,6 @@ def add_main_layout(self):
272311
Override this method for laying out any learner-specific parameter controls.
273312
See setup_layout() method for execution order.
274313
"""
275-
pass
276314

277315
def add_classification_layout(self, box):
278316
"""Creates layout for classification specific options.
@@ -281,7 +319,6 @@ def add_classification_layout(self, box):
281319
and regression learners require different options.
282320
See `setup_layout()` method for execution order.
283321
"""
284-
pass
285322

286323
def add_regression_layout(self, box):
287324
"""Creates layout for regression specific options.
@@ -290,7 +327,6 @@ def add_regression_layout(self, box):
290327
and regression learners require different options.
291328
See `setup_layout()` method for execution order.
292329
"""
293-
pass
294330

295331
def add_learner_name_widget(self):
296332
self.name_line_edit = gui.lineEdit(

Orange/widgets/utils/tests/test_owlearnerwidget.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock
1+
from unittest.mock import Mock, patch
22

33
import scipy.sparse as sp
44

@@ -218,3 +218,73 @@ def check_name(name):
218218
check_name("Bar")
219219
w.set_default_learner_name("")
220220
check_name("Blarg")
221+
222+
def test_preprocessor_warning(self):
223+
class TestLearnerNoPreprocess(Learner):
224+
name = "Test"
225+
__returns__ = Mock()
226+
227+
class TestWidgetNoPreprocess(OWBaseLearner):
228+
name = "Test"
229+
LEARNER = TestLearnerNoPreprocess
230+
231+
class TestLearnerPreprocess(Learner):
232+
name = "Test"
233+
preprocessors = [Mock()]
234+
__returns__ = Mock()
235+
236+
class TestWidgetPreprocess(OWBaseLearner):
237+
name = "Test"
238+
LEARNER = TestLearnerPreprocess
239+
240+
class TestFitterPreprocess(Fitter):
241+
name = "Test"
242+
preprocessors = [Mock()]
243+
__returns__ = Mock()
244+
245+
class TestWidgetPreprocessFit(OWBaseLearner):
246+
name = "Test"
247+
LEARNER = TestFitterPreprocess
248+
249+
wno = self.create_widget(TestWidgetNoPreprocess)
250+
wyes = self.create_widget(TestWidgetPreprocess)
251+
wfit = self.create_widget(TestWidgetPreprocessFit)
252+
253+
self.assertFalse(wno.Information.ignored_preprocessors.is_shown())
254+
self.assertFalse(wyes.Information.ignored_preprocessors.is_shown())
255+
self.assertFalse(wfit.Information.ignored_preprocessors.is_shown())
256+
257+
pp = continuize.Continuize()
258+
self.send_signal(wno.Inputs.preprocessor, pp)
259+
self.send_signal(wyes.Inputs.preprocessor, pp)
260+
self.send_signal(wfit.Inputs.preprocessor, pp)
261+
262+
self.assertFalse(wno.Information.ignored_preprocessors.is_shown())
263+
self.assertTrue(wyes.Information.ignored_preprocessors.is_shown())
264+
self.assertFalse(wfit.Information.ignored_preprocessors.is_shown())
265+
266+
self.send_signal(wno.Inputs.preprocessor, None)
267+
self.send_signal(wyes.Inputs.preprocessor, None)
268+
self.send_signal(wfit.Inputs.preprocessor, None)
269+
270+
self.assertFalse(wno.Information.ignored_preprocessors.is_shown())
271+
self.assertFalse(wyes.Information.ignored_preprocessors.is_shown())
272+
self.assertFalse(wfit.Information.ignored_preprocessors.is_shown())
273+
274+
def test_multiple_sends(self):
275+
class TestLearner(Learner):
276+
name = "Test"
277+
__returns__ = Mock()
278+
279+
class TestWidget(OWBaseLearner):
280+
name = "Test"
281+
LEARNER = TestLearner
282+
283+
widget = self.create_widget(TestWidget)
284+
pp = continuize.Continuize()
285+
with patch.object(widget.Outputs.learner, "send") as model_send, \
286+
patch.object(widget.Outputs.model, "send") as learner_send:
287+
self.send_signals([(widget.Inputs.data, self.iris),
288+
(widget.Inputs.preprocessor, pp)])
289+
learner_send.assert_called_once()
290+
model_send.assert_called_once()

0 commit comments

Comments
 (0)