Skip to content

Commit 24638c0

Browse files
committed
Parameter Fitter: Retain settings on input
1 parent 8376e3b commit 24638c0

File tree

9 files changed

+127
-90
lines changed

9 files changed

+127
-90
lines changed

Orange/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable
44
import re
55
import warnings
6-
from typing import Callable, Optional, NamedTuple, Union, Type
6+
from typing import Callable, Optional, NamedTuple, Type
77

88
import numpy as np
99
import scipy
@@ -186,8 +186,8 @@ def active_preprocessors(self):
186186
self.preprocessors is not type(self).preprocessors):
187187
yield from type(self).preprocessors
188188

189-
# declared for derived classes, pylint: disable=unused-argument
190-
def fitted_parameters(self, problem_type: Union[str, Table, Domain]) -> list:
189+
@property
190+
def fitted_parameters(self) -> list:
191191
return []
192192

193193
# pylint: disable=no-self-use

Orange/classification/random_forest.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sklearn.ensemble as skl_ensemble
22

3-
from Orange.base import RandomForestModel, Learner
3+
from Orange.base import RandomForestModel
44
from Orange.classification import SklLearner, SklModel
55
from Orange.classification.tree import SklTreeClassifier
66
from Orange.data import Variable, DiscreteVariable
@@ -58,7 +58,3 @@ def __init__(self,
5858
preprocessors=None):
5959
super().__init__(preprocessors=preprocessors)
6060
self.params = vars()
61-
62-
def fitted_parameters(self, _) -> list[Learner.FittedParameter]:
63-
return [self.FittedParameter("n_estimators", "Number of trees",
64-
int, 1, None)]

Orange/evaluation/testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(self, data=None, *,
9898
row_indices=None, folds=None, score_by_folds=True,
9999
learners=None, models=None, failed=None,
100100
actual=None, predicted=None, probabilities=None,
101+
# pylint: disable=unused-argument
101102
store_data=None, store_models=None,
102103
train_time=None, test_time=None):
103104
"""

Orange/modelling/randomforest.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from typing import Union
2-
31
from Orange.base import RandomForestModel, Learner
42
from Orange.classification import RandomForestLearner as RFClassification
5-
from Orange.data import Variable, Domain, Table
3+
from Orange.data import Variable
64
from Orange.modelling import SklFitter
75
from Orange.preprocess.score import LearnerScorer
86
from Orange.regression import RandomForestRegressionLearner as RFRegression
@@ -27,8 +25,7 @@ class RandomForestLearner(SklFitter, _FeatureScorerMixin):
2725

2826
__returns__ = RandomForestModel
2927

30-
def fitted_parameters(
31-
self,
32-
problem_type: Union[str, Table, Domain]
33-
) -> list[Learner.FittedParameter]:
34-
return self.get_learner(problem_type).fitted_parameters(problem_type)
28+
@property
29+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
30+
return [self.FittedParameter("n_estimators", "Number of trees",
31+
int, 1, None)]

Orange/regression/pls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def incompatibility_reason(self, domain):
255255
reason = "Only numeric target variables expected."
256256
return reason
257257

258-
def fitted_parameters(self, _) -> list[Learner.FittedParameter]:
258+
@property
259+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
259260
return [self.FittedParameter("n_components", "Components",
260261
int, 1, None)]
261262

Orange/regression/random_forest.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sklearn.ensemble as skl_ensemble
22

3-
from Orange.base import RandomForestModel, Learner
3+
from Orange.base import RandomForestModel
44
from Orange.data import Variable, ContinuousVariable
55
from Orange.preprocess.score import LearnerScorer
66
from Orange.regression import SklLearner, SklModel
@@ -57,7 +57,3 @@ def __init__(self,
5757
preprocessors=None):
5858
super().__init__(preprocessors=preprocessors)
5959
self.params = vars()
60-
61-
def fitted_parameters(self, _) -> list[Learner.FittedParameter]:
62-
return [self.FittedParameter("n_estimators", "Number of trees",
63-
int, 1, None)]

Orange/regression/tests/test_pls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def table(rows, attr, variables):
2222

2323
class TestPLSRegressionLearner(unittest.TestCase):
2424
def test_fitted_parameters(self):
25-
fitted_parameters = PLSRegressionLearner().fitted_parameters(None)
25+
fitted_parameters = PLSRegressionLearner().fitted_parameters
2626
self.assertIsInstance(fitted_parameters, list)
2727
self.assertEqual(len(fitted_parameters), 1)
2828

Orange/widgets/evaluate/owparameterfitter.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
N_FOLD = 7
3636
MIN_MAX_SPIN = 100000
3737
ScoreType = tuple[int, tuple[float, float]]
38-
# scores, score name, tick label
38+
# scores, score name, label
3939
FitterResults = tuple[list[ScoreType], str, str]
4040

4141

@@ -173,6 +173,7 @@ def clear_all(self):
173173
self.__bar_item_tr = None
174174
self.__bar_item_cv = None
175175
self.__data = None
176+
self.setLabel(axis="bottom", text=None)
176177
self.setLabel(axis="left", text=None)
177178
self.getAxis("bottom").setTicks(None)
178179

@@ -320,12 +321,12 @@ class Inputs:
320321
DEFAULT_PARAMETER_INDEX = 0
321322
DEFAULT_MINIMUM = 1
322323
DEFAULT_MAXIMUM = 9
323-
parameter_index = Setting(DEFAULT_PARAMETER_INDEX, schema_only=True)
324+
parameter_index = Setting(DEFAULT_PARAMETER_INDEX)
324325
FROM_RANGE, MANUAL = range(2)
325326
type: int = Setting(FROM_RANGE)
326-
minimum: int = Setting(DEFAULT_MINIMUM, schema_only=True)
327-
maximum: int = Setting(DEFAULT_MAXIMUM, schema_only=True)
328-
manual_steps: str = Setting("", schema_only=True)
327+
minimum: int = Setting(DEFAULT_MINIMUM)
328+
maximum: int = Setting(DEFAULT_MAXIMUM)
329+
manual_steps: str = Setting("")
329330
auto_commit = Setting(True)
330331

331332
class Error(OWWidget.Error):
@@ -345,13 +346,10 @@ def __init__(self):
345346
self._data: Optional[Table] = None
346347
self._learner: Optional[Learner] = None
347348
self.__parameters_model = QStandardItemModel()
348-
349-
self.__pending_parameter_index = self.parameter_index \
350-
if self.parameter_index != self.DEFAULT_PARAMETER_INDEX else None
351-
self.__pending_minimum = self.minimum \
352-
if self.minimum != self.DEFAULT_MINIMUM else None
353-
self.__pending_maximum = self.maximum \
354-
if self.maximum != self.DEFAULT_MAXIMUM else None
349+
self.__initialize_settings = \
350+
self.parameter_index == self.DEFAULT_PARAMETER_INDEX and \
351+
self.minimum == self.DEFAULT_MINIMUM and \
352+
self.maximum == self.DEFAULT_MAXIMUM
355353

356354
self.setup_gui()
357355
VisualSettingsDialog(
@@ -418,10 +416,13 @@ def _():
418416

419417
gui.auto_apply(self.buttonsArea, self, "auto_commit")
420418

419+
self._update_preview()
420+
421421
def __on_type_changed(self):
422422
self._settings_changed()
423423

424424
def __on_parameter_changed(self):
425+
self.__initialize_settings = True
425426
self._set_range_controls()
426427
self._settings_changed()
427428

@@ -439,19 +440,16 @@ def _settings_changed(self):
439440

440441
@property
441442
def fitted_parameters(self) -> list:
442-
if not self._learner \
443-
or isinstance(self._learner, Fitter) and not self._data:
443+
if not self._learner:
444444
return []
445-
return self._learner.fitted_parameters(self._data)
445+
return self._learner.fitted_parameters
446446

447447
@property
448448
def initial_parameters(self) -> dict:
449449
if not self._learner:
450450
return {}
451451
if isinstance(self._learner, Fitter):
452-
if not self._data:
453-
return {}
454-
return self._learner.get_params(self._data)
452+
return self._learner.get_params(self._data or "classification")
455453
return self._learner.params
456454

457455
@property
@@ -495,38 +493,32 @@ def _steps_from_manual(self) -> tuple[int, ...]:
495493
@Inputs.data
496494
@check_multiple_targets_input
497495
def set_data(self, data: Optional[Table]):
496+
self.Error.not_enough_data.clear()
497+
self.Error.missing_target.clear()
498498
self._data = data
499+
if self._data and len(self._data) < N_FOLD:
500+
self.Error.not_enough_data()
501+
self._data = None
502+
if self._data and len(self._data.domain.class_vars) < 1:
503+
self.Error.missing_target()
504+
self._data = None
499505

500506
@Inputs.learner
501507
def set_learner(self, learner: Optional[Learner]):
508+
if self._learner:
509+
self.__initialize_settings = \
510+
not isinstance(self._learner, type(learner))
502511
self._learner = learner
503512

504513
def handleNewSignals(self):
505514
self.Warning.clear()
506515
self.Error.unknown_err.clear()
507-
self.Error.not_enough_data.clear()
508516
self.Error.incompatible_learner.clear()
509517
self.Error.manual_steps_error.clear()
510518
self.Error.min_max_error.clear()
511-
self.Error.missing_target.clear()
512519
self.clear()
513520

514-
if self._data is None or self._learner is None:
515-
return
516-
517-
if self._data and len(self._data) < N_FOLD:
518-
self.Error.not_enough_data()
519-
self._data = None
520-
return
521-
522-
if self._data and len(self._data.domain.class_vars) < 1:
523-
self.Error.missing_target()
524-
self._data = None
525-
return
526-
527-
reason = self._learner.incompatibility_reason(self._data.domain)
528-
if reason:
529-
self.Error.incompatible_learner(reason)
521+
if self._learner is None:
530522
return
531523

532524
for param in self.fitted_parameters:
@@ -535,19 +527,22 @@ def handleNewSignals(self):
535527
if not self.fitted_parameters:
536528
self.Warning.no_parameters(self._learner.name)
537529
else:
538-
if self.__pending_parameter_index is not None:
539-
self.parameter_index = self.__pending_parameter_index
530+
if self.__initialize_settings:
531+
self.parameter_index = 0
532+
else:
540533
self.__combo.setCurrentIndex(self.parameter_index)
541-
self.__pending_parameter_index = None
542534
self._set_range_controls()
543-
if self.__pending_minimum is not None:
544-
self.minimum = self.__pending_minimum
545-
self.__pending_minimum = None
546-
if self.__pending_maximum is not None:
547-
self.maximum = self.__pending_maximum
548-
self.__pending_maximum = None
549535

550536
self._update_preview()
537+
538+
if self._data is None:
539+
return
540+
541+
reason = self._learner.incompatibility_reason(self._data.domain)
542+
if reason:
543+
self.Error.incompatible_learner(reason)
544+
return
545+
551546
self.commit.now()
552547

553548
def _set_range_controls(self):
@@ -561,19 +556,24 @@ def _set_range_controls(self):
561556
if param.min is not None:
562557
self.__spin_min.setMinimum(param.min)
563558
self.__spin_max.setMinimum(param.min)
564-
self.minimum = param.min
559+
if self.__initialize_settings:
560+
self.minimum = param.min
565561
else:
566562
self.__spin_min.setMinimum(-MIN_MAX_SPIN)
567563
self.__spin_max.setMinimum(-MIN_MAX_SPIN)
568-
self.minimum = self.initial_parameters[param.name]
564+
if self.__initialize_settings:
565+
self.minimum = self.initial_parameters[param.name]
569566
if param.max is not None:
570567
self.__spin_min.setMaximum(param.max)
571568
self.__spin_max.setMaximum(param.max)
572-
self.maximum = param.max
569+
if self.__initialize_settings:
570+
self.maximum = param.max
573571
else:
574572
self.__spin_min.setMaximum(MIN_MAX_SPIN)
575573
self.__spin_max.setMaximum(MIN_MAX_SPIN)
576-
self.maximum = self.initial_parameters[param.name]
574+
if self.__initialize_settings:
575+
self.maximum = self.initial_parameters[param.name]
576+
self.__initialize_settings = False
577577

578578
tip = "Enter a list of values"
579579
if param.min is not None:

0 commit comments

Comments
 (0)