Skip to content

Commit dac6438

Browse files
Learner: The attribute is now expected everywhere
1 parent e8423b3 commit dac6438

File tree

7 files changed

+54
-27
lines changed

7 files changed

+54
-27
lines changed

Orange/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class Learner(_ReprableWithPreprocessors):
7272
This property is needed mainly because of the `Fitter` class, which can
7373
not know in advance which preprocessors it will need to use. Therefore
7474
this resolves the active preprocessors using a lazy approach.
75+
params : dict
76+
The params that the learner is constructed with.
7577
7678
"""
7779
supports_multiclass = False

Orange/classification/tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def __init__(
6464
self.min_samples_split = min_samples_split
6565
self.sufficient_majority = sufficient_majority
6666
self.max_depth = max_depth
67+
self.params = {k: v for k, v in vars().items()
68+
if k not in ('args', 'kwargs')}
6769

6870
def _select_attr(self, data):
6971
"""Select the attribute for the next split.

Orange/modelling/base.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ class Fitter(Learner, metaclass=FitterMeta):
3838

3939
def __init__(self, preprocessors=None, **kwargs):
4040
super().__init__(preprocessors=preprocessors)
41-
self.params = kwargs
41+
self.kwargs = kwargs
4242
# Make sure to pass preprocessor params to individual learners
43-
self.params['preprocessors'] = preprocessors
43+
self.kwargs['preprocessors'] = preprocessors
4444
self.__learners = {self.CLASSIFICATION: None, self.REGRESSION: None}
4545

4646
def _fit_model(self, data):
@@ -56,7 +56,14 @@ def _fit_model(self, data):
5656
return learner.fit(X, Y, W)
5757

5858
def get_learner(self, problem_type):
59-
"""Get the learner for a given problem type."""
59+
"""Get the learner for a given problem type.
60+
61+
Returns
62+
-------
63+
Learner
64+
The appropriate learner for the given problem type.
65+
66+
"""
6067
# Prevent trying to access the learner when problem type is None
6168
if problem_type not in self.__fits__:
6269
raise TypeError("No learner to handle '{}'".format(problem_type))
@@ -69,7 +76,7 @@ def get_learner(self, problem_type):
6976
def __kwargs(self, problem_type):
7077
learner_kwargs = set(
7178
self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
72-
changed_kwargs = self._change_kwargs(self.params, problem_type)
79+
changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
7380
return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
7481

7582
def _change_kwargs(self, kwargs, problem_type):
@@ -95,3 +102,13 @@ def supports_weights(self):
95102
and self.get_learner(self.CLASSIFICATION).supports_weights) and (
96103
hasattr(self.get_learner(self.REGRESSION), 'supports_weights')
97104
and self.get_learner(self.REGRESSION).supports_weights)
105+
106+
@property
107+
def params(self):
108+
raise TypeError(
109+
'A fitter does not have its own params. If you need to access '
110+
'learner params, please use the `get_params` method.')
111+
112+
def get_params(self, problem_type):
113+
"""Access the specific learner params of a given learner."""
114+
return self.get_learner(problem_type).params

Orange/regression/tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __init__(
5252
self.min_samples_leaf = min_samples_leaf
5353
self.min_samples_split = min_samples_split
5454
self.max_depth = max_depth
55+
self.params = {k: v for k, v in vars().items()
56+
if k not in ('args', 'kwargs')}
5557

5658
def _select_attr(self, data):
5759
"""Select the attribute for the next split.

Orange/widgets/regression/tests/test_owadaboostregression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def setUp(self):
1919
self.valid_datasets = (self.data,)
2020
losses = [loss.lower() for loss in self.widget.losses]
2121
self.parameters = [
22-
ParameterMapping('loss', self.widget.reg_algorithm_combo, losses),
22+
ParameterMapping('loss', self.widget.reg_algorithm_combo, losses,
23+
problem_type='regression'),
2324
ParameterMapping('learning_rate', self.widget.learning_rate_spin),
2425
ParameterMapping('n_estimators', self.widget.n_estimators_spin)]
2526

Orange/widgets/regression/tests/test_owsgdregression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ def setUp(self):
1414
self.valid_datasets = (self.housing,)
1515
self.parameters = [
1616
ParameterMapping('loss', self.widget.reg_loss_function_combo,
17-
list(zip(*self.widget.reg_losses))[1]),
18-
ParameterMapping.from_attribute(self.widget, 'reg_epsilon', 'epsilon'),
17+
list(zip(*self.widget.reg_losses))[1],
18+
problem_type='regression'),
19+
ParameterMapping('epsilon', self.widget.reg_epsilon_spin,
20+
problem_type='regression'),
1921
ParameterMapping('penalty', self.widget.penalty_combo,
2022
list(zip(*self.widget.penalties))[1]),
2123
ParameterMapping.from_attribute(self.widget, 'alpha'),

Orange/widgets/tests/base.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -465,30 +465,33 @@ def test_output_model_name(self):
465465
self.widget.apply_button.button.click()
466466
self.assertEqual(self.get_output(self.model_name).name, new_name)
467467

468+
def _get_param_value(self, learner, param):
469+
if isinstance(learner, Fitter):
470+
# Both is just a way to indicate to the tests; fitters don't
471+
# actually support this
472+
if param.problem_type == "both":
473+
problem_type = learner.CLASSIFICATION
474+
else:
475+
problem_type = param.problem_type
476+
return learner.get_params(problem_type).get(param.name)
477+
else:
478+
return learner.params.get(param.name)
479+
468480
def test_parameters_default(self):
469481
"""Check if learner's parameters are set to default (widget's) values
470482
"""
471483
for dataset in self.valid_datasets:
472484
self.send_signal("Data", dataset)
473485
self.widget.apply_button.button.click()
474-
if hasattr(self.widget.learner, "params"):
475-
learner_params = self.widget.learner.params
476-
for parameter in self.parameters:
477-
# Skip if the param isn't used for the given data type
478-
if self._should_check_parameter(parameter, dataset):
479-
self.assertEqual(learner_params.get(parameter.name),
480-
parameter.get_value())
486+
for parameter in self.parameters:
487+
# Skip if the param isn't used for the given data type
488+
if self._should_check_parameter(parameter, dataset):
489+
self.assertEqual(
490+
self._get_param_value(self.widget.learner, parameter),
491+
parameter.get_value())
481492

482493
def test_parameters(self):
483494
"""Check learner and model for various values of all parameters"""
484-
485-
def get_value(learner, name):
486-
# Handle SKL and skl-like learners, and non-SKL learners
487-
if hasattr(learner, "params"):
488-
return learner.params.get(name)
489-
else:
490-
return getattr(learner, name)
491-
492495
# Test params on every valid dataset, since some attributes may apply
493496
# to only certain problem types
494497
for dataset in self.valid_datasets:
@@ -504,24 +507,22 @@ def get_value(learner, name):
504507
for value in parameter.values:
505508
parameter.set_value(value)
506509
self.widget.apply_button.button.click()
507-
param = get_value(self.widget.learner, parameter.name)
510+
param = self._get_param_value(self.widget.learner, parameter)
508511
self.assertEqual(
509512
param, parameter.get_value(),
510513
"Mismatching setting for parameter '%s'" % parameter)
511514
self.assertEqual(
512515
param, value,
513516
"Mismatching setting for parameter '%s'" % parameter)
514-
param = get_value(self.get_output("Learner"),
515-
parameter.name)
517+
param = self._get_param_value(self.get_output("Learner"), parameter)
516518
self.assertEqual(
517519
param, value,
518520
"Mismatching setting for parameter '%s'" % parameter)
519521

520522
if issubclass(self.widget.LEARNER, SklModel):
521523
model = self.get_output(self.model_name)
522524
if model is not None:
523-
self.assertEqual(get_value(model, parameter.name),
524-
value)
525+
self.assertEqual(self._get_param_value(model, parameter), value)
525526
self.assertFalse(self.widget.Error.active)
526527
else:
527528
self.assertTrue(self.widget.Error.active)

0 commit comments

Comments
 (0)