Skip to content

Commit 9b3faf4

Browse files
authored
Merge pull request #1977 from pavlin-policar/fix-fitter-recursion
[FIX] Fitter: Fix infinite recursion in __getattr__
2 parents ebf8354 + 49188ee commit 9b3faf4

File tree

9 files changed

+93
-61
lines changed

9 files changed

+93
-61
lines changed

Orange/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class Learner(_ReprableWithPreprocessors):
7272
This property is needed mainly because of the `Fitter` class, which can
7373
not know in advance, which preprocessors it will need to use. Therefore
7474
this resolves the active preprocessors using a lazy approach.
75+
params : dict
76+
The params that the learner is constructed with.
7577
7678
"""
7779
supports_multiclass = False
@@ -114,19 +116,21 @@ def __call__(self, data):
114116
self.__class__.__name__)
115117

116118
self.domain = data.domain
117-
118-
if type(self).fit is Learner.fit:
119-
model = self.fit_storage(data)
120-
else:
121-
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
122-
model = self.fit(X, Y, W)
119+
model = self._fit_model(data)
123120
model.domain = data.domain
124121
model.supports_multiclass = self.supports_multiclass
125122
model.name = self.name
126123
model.original_domain = origdomain
127124
model.original_data = origdata
128125
return model
129126

127+
def _fit_model(self, data):
128+
if type(self).fit is Learner.fit:
129+
return self.fit_storage(data)
130+
else:
131+
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
132+
return self.fit(X, Y, W)
133+
130134
def preprocess(self, data):
131135
"""Apply the `preprocessors` to the data"""
132136
for pp in self.active_preprocessors:

Orange/classification/tree.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@ def __init__(
5959
min_samples_leaf=1, min_samples_split=2, sufficient_majority=0.95,
6060
**kwargs):
6161
super().__init__(*args, **kwargs)
62-
self.binarize = binarize
63-
self.min_samples_leaf = min_samples_leaf
64-
self.min_samples_split = min_samples_split
65-
self.sufficient_majority = sufficient_majority
66-
self.max_depth = max_depth
62+
self.params = {}
63+
self.binarize = self.params['binarize'] = binarize
64+
self.min_samples_leaf = self.params['min_samples_leaf'] = min_samples_leaf
65+
self.min_samples_split = self.params['min_samples_split'] = min_samples_split
66+
self.sufficient_majority = self.params['sufficient_majority'] = sufficient_majority
67+
self.max_depth = self.params['max_depth'] = max_depth
6768

6869
def _select_attr(self, data):
6970
"""Select the attribute for the next split.

Orange/modelling/base.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,32 @@ def __init__(self, preprocessors=None, **kwargs):
4141
self.kwargs = kwargs
4242
# Make sure to pass preprocessor params to individual learners
4343
self.kwargs['preprocessors'] = preprocessors
44-
self.problem_type = None
4544
self.__learners = {self.CLASSIFICATION: None, self.REGRESSION: None}
4645

47-
def __call__(self, data):
48-
# Set the appropriate problem type from the data
49-
self.problem_type = self.CLASSIFICATION if \
50-
data.domain.has_discrete_class else self.REGRESSION
51-
return self.get_learner(self.problem_type)(data)
46+
def _fit_model(self, data):
47+
if data.domain.has_discrete_class:
48+
learner = self.get_learner(self.CLASSIFICATION)
49+
else:
50+
learner = self.get_learner(self.REGRESSION)
51+
52+
if type(self).fit is Learner.fit:
53+
return learner.fit_storage(data)
54+
else:
55+
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
56+
return learner.fit(X, Y, W)
5257

5358
def get_learner(self, problem_type):
54-
"""Get the learner for a given problem type."""
59+
"""Get the learner for a given problem type.
60+
61+
Returns
62+
-------
63+
Learner
64+
The appropriate learner for the given problem type.
65+
66+
"""
5567
# Prevent trying to access the learner when problem type is None
5668
if problem_type not in self.__fits__:
57-
# We're mostly called from __getattr__ via getattr, so we should
58-
# raise AttributeError instead of TypeError
59-
raise AttributeError("No learner to handle '{}'".format(problem_type))
69+
raise TypeError("No learner to handle '{}'".format(problem_type))
6070
if self.__learners[problem_type] is None:
6171
learner = self.__fits__[problem_type](**self.__kwargs(problem_type))
6272
learner.use_default_preprocessors = self.use_default_preprocessors
@@ -66,7 +76,7 @@ def get_learner(self, problem_type):
6676
def __kwargs(self, problem_type):
6777
learner_kwargs = set(
6878
self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
69-
changed_kwargs = self._change_kwargs(self.kwargs, self.problem_type)
79+
changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
7080
return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
7181

7282
def _change_kwargs(self, kwargs, problem_type):
@@ -93,8 +103,12 @@ def supports_weights(self):
93103
hasattr(self.get_learner(self.REGRESSION), 'supports_weights')
94104
and self.get_learner(self.REGRESSION).supports_weights)
95105

96-
def __getattr__(self, item):
97-
# Make parameters accessible on the learner for simpler testing
98-
if item in self.kwargs:
99-
return self.kwargs[item]
100-
return getattr(self.get_learner(self.problem_type), item)
106+
@property
107+
def params(self):
108+
raise TypeError(
109+
'A fitter does not have its own params. If you need to access '
110+
'learner params, please use the `get_params` method.')
111+
112+
def get_params(self, problem_type):
113+
"""Access the specific learner params of a given learner."""
114+
return self.get_learner(problem_type).params

Orange/regression/tree.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def __init__(
4848
binarize=False, min_samples_leaf=1, min_samples_split=2,
4949
max_depth=None, **kwargs):
5050
super().__init__(*args, **kwargs)
51-
self.binarize = binarize
52-
self.min_samples_leaf = min_samples_leaf
53-
self.min_samples_split = min_samples_split
54-
self.max_depth = max_depth
51+
self.params = {}
52+
self.binarize = self.params['binarize'] = binarize
53+
self.min_samples_leaf = self.params['min_samples_leaf'] = min_samples_leaf
54+
self.min_samples_split = self.params['min_samples_split'] = min_samples_split
55+
self.max_depth = self.params['max_depth'] = max_depth
5556

5657
def _select_attr(self, data):
5758
"""Select the attribute for the next split.

Orange/tests/test_fitter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import unittest
2-
from unittest.mock import Mock
2+
from unittest.mock import Mock, patch
33

44
from Orange.classification.base_classification import LearnerClassification
55
from Orange.data import Table
6+
from Orange.evaluation import CrossValidation
67
from Orange.modelling import Fitter
78
from Orange.preprocess import Randomize
89
from Orange.regression.base_regression import LearnerRegression
@@ -92,14 +93,13 @@ class DummyFitter(Fitter):
9293
fitter = DummyFitter()
9394
self.assertEqual(fitter.get_learner(Fitter.CLASSIFICATION).param, 1)
9495
self.assertEqual(fitter.get_learner(Fitter.REGRESSION).param, 2)
95-
self.assertEqual(fitter.name, 'dummy')
9696

9797
# Pass specific params
9898
try:
9999
fitter = DummyFitter(classification_param=10, regression_param=20)
100100
self.assertEqual(fitter.get_learner(Fitter.CLASSIFICATION).param, 10)
101101
self.assertEqual(fitter.get_learner(Fitter.REGRESSION).param, 20)
102-
except AttributeError:
102+
except TypeError:
103103
self.fail('Fitter did not properly distribute params to learners')
104104

105105
def test_error_for_data_type_with_no_learner(self):
@@ -126,3 +126,7 @@ def test_correctly_sets_preprocessors_on_learner(self):
126126
self.assertEqual(
127127
tuple(learner.active_preprocessors), (pp,),
128128
'Fitter did not properly pass its preprocessors to its learners')
129+
130+
def test_n_jobs_fitting(self):
131+
with patch('Orange.evaluation.testing.CrossValidation._MIN_NJOBS_X_SIZE', 1):
132+
CrossValidation(self.heart_disease, [DummyFitter()], k=5, n_jobs=5)

Orange/widgets/regression/tests/test_owadaboostregression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def setUp(self):
1919
self.valid_datasets = (self.data,)
2020
losses = [loss.lower() for loss in self.widget.losses]
2121
self.parameters = [
22-
ParameterMapping('loss', self.widget.reg_algorithm_combo, losses),
22+
ParameterMapping('loss', self.widget.reg_algorithm_combo, losses,
23+
problem_type='regression'),
2324
ParameterMapping('learning_rate', self.widget.learning_rate_spin),
2425
ParameterMapping('n_estimators', self.widget.n_estimators_spin)]
2526

Orange/widgets/regression/tests/test_owsgdregression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ def setUp(self):
1515
self.valid_datasets = (self.housing,)
1616
self.parameters = [
1717
ParameterMapping('loss', self.widget.reg_loss_function_combo,
18-
list(zip(*self.widget.reg_losses))[1]),
19-
ParameterMapping.from_attribute(self.widget, 'reg_epsilon', 'epsilon'),
18+
list(zip(*self.widget.reg_losses))[1],
19+
problem_type='regression'),
20+
ParameterMapping('epsilon', self.widget.reg_epsilon_spin,
21+
problem_type='regression'),
2022
ParameterMapping('penalty', self.widget.penalty_combo,
2123
list(zip(*self.widget.penalties))[1]),
2224
ParameterMapping.from_attribute(self.widget, 'alpha'),

Orange/widgets/regression/tests/test_owsvmregression.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ def setter(value):
2828
gamma_spin.setValue(value)
2929

3030
self.parameters = [
31-
ParameterMapping("C", self.widget.C_spin),
32-
ParameterMapping("epsilon", self.widget.epsilon_spin),
31+
ParameterMapping("C", self.widget.C_spin,
32+
problem_type="regression"),
33+
ParameterMapping("epsilon", self.widget.epsilon_spin,
34+
problem_type="regression"),
3335
ParameterMapping("gamma", self.widget._kernel_params[0],
3436
values=values, setter=setter, getter=getter),
3537
ParameterMapping("coef0", self.widget._kernel_params[1]),
@@ -44,8 +46,10 @@ def test_parameters_svr_type(self):
4446
# setChecked(True) does not trigger callback event
4547
self.widget.nu_radio.click()
4648
self.assertEqual(self.widget.svm_type, OWSVM.Nu_SVM)
47-
self.parameters[0] = ParameterMapping("C", self.widget.nu_C_spin)
48-
self.parameters[1] = ParameterMapping("nu", self.widget.nu_spin)
49+
self.parameters[0] = ParameterMapping("C", self.widget.nu_C_spin,
50+
problem_type="regression")
51+
self.parameters[1] = ParameterMapping("nu", self.widget.nu_spin,
52+
problem_type="regression")
4953
self.test_parameters()
5054

5155
def test_kernel_equation(self):

Orange/widgets/tests/base.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -465,30 +465,33 @@ def test_output_model_name(self):
465465
self.widget.apply_button.button.click()
466466
self.assertEqual(self.get_output(self.model_name).name, new_name)
467467

468+
def _get_param_value(self, learner, param):
469+
if isinstance(learner, Fitter):
470+
# Both is just a way to indicate to the tests; fitters don't
471+
# actually support this
472+
if param.problem_type == "both":
473+
problem_type = learner.CLASSIFICATION
474+
else:
475+
problem_type = param.problem_type
476+
return learner.get_params(problem_type).get(param.name)
477+
else:
478+
return learner.params.get(param.name)
479+
468480
def test_parameters_default(self):
469481
"""Check if learner's parameters are set to default (widget's) values
470482
"""
471483
for dataset in self.valid_datasets:
472484
self.send_signal("Data", dataset)
473485
self.widget.apply_button.button.click()
474-
if hasattr(self.widget.learner, "params"):
475-
learner_params = self.widget.learner.params
476-
for parameter in self.parameters:
477-
# Skip if the param isn't used for the given data type
478-
if self._should_check_parameter(parameter, dataset):
479-
self.assertEqual(learner_params.get(parameter.name),
480-
parameter.get_value())
486+
for parameter in self.parameters:
487+
# Skip if the param isn't used for the given data type
488+
if self._should_check_parameter(parameter, dataset):
489+
self.assertEqual(
490+
self._get_param_value(self.widget.learner, parameter),
491+
parameter.get_value())
481492

482493
def test_parameters(self):
483494
"""Check learner and model for various values of all parameters"""
484-
485-
def get_value(learner, name):
486-
# Handle SKL and skl-like learners, and non-SKL learners
487-
if hasattr(learner, "params"):
488-
return learner.params.get(name)
489-
else:
490-
return getattr(learner, name)
491-
492495
# Test params on every valid dataset, since some attributes may apply
493496
# to only certain problem types
494497
for dataset in self.valid_datasets:
@@ -504,24 +507,22 @@ def get_value(learner, name):
504507
for value in parameter.values:
505508
parameter.set_value(value)
506509
self.widget.apply_button.button.click()
507-
param = get_value(self.widget.learner, parameter.name)
510+
param = self._get_param_value(self.widget.learner, parameter)
508511
self.assertEqual(
509512
param, parameter.get_value(),
510513
"Mismatching setting for parameter '%s'" % parameter)
511514
self.assertEqual(
512515
param, value,
513516
"Mismatching setting for parameter '%s'" % parameter)
514-
param = get_value(self.get_output("Learner"),
515-
parameter.name)
517+
param = self._get_param_value(self.get_output("Learner"), parameter)
516518
self.assertEqual(
517519
param, value,
518520
"Mismatching setting for parameter '%s'" % parameter)
519521

520522
if issubclass(self.widget.LEARNER, SklModel):
521523
model = self.get_output(self.model_name)
522524
if model is not None:
523-
self.assertEqual(get_value(model, parameter.name),
524-
value)
525+
self.assertEqual(self._get_param_value(model, parameter), value)
525526
self.assertFalse(self.widget.Error.active)
526527
else:
527528
self.assertTrue(self.widget.Error.active)

0 commit comments

Comments
 (0)