Skip to content

Commit 2869a55

Browse files
authored
Merge pull request #6921 from VesnaT/summary_of_fit
[ENH] Parameter Fitter: Basic implementation
2 parents 747b3b2 + 932fe14 commit 2869a55

File tree

16 files changed

+1834
-19
lines changed

16 files changed

+1834
-19
lines changed

Orange/base.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable
44
import re
55
import warnings
6-
from typing import Callable, Dict, Optional
6+
from typing import Callable, Optional, NamedTuple, Type
77

88
import numpy as np
99
import scipy
@@ -88,6 +88,13 @@ class Learner(ReprableWithPreprocessors):
8888
#: fitting the model
8989
preprocessors = ()
9090

91+
class FittedParameter(NamedTuple):
92+
name: str
93+
label: str
94+
type: Type
95+
min: Optional[int] = None
96+
max: Optional[int] = None
97+
9198
# Note: Do not use this class attribute.
9299
# It remains here for compatibility reasons.
93100
learner_adequacy_err_msg = ''
@@ -179,6 +186,10 @@ def active_preprocessors(self):
179186
self.preprocessors is not type(self).preprocessors):
180187
yield from type(self).preprocessors
181188

189+
@property
190+
def fitted_parameters(self) -> list:
191+
return []
192+
182193
# pylint: disable=no-self-use
183194
def incompatibility_reason(self, _: Domain) -> Optional[str]:
184195
"""Return None if a learner can fit domain or string explaining why it can not."""
@@ -883,5 +894,5 @@ def __init__(self, preprocessors=None, **kwargs):
883894
self.params = kwargs
884895

885896
@SklLearner.params.setter
886-
def params(self, values: Dict):
897+
def params(self, values: dict):
887898
self._params = values

Orange/evaluation/testing.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _identity(x):
2525

2626

2727
def _mp_worker(fold_i, train_data, test_data, learner_i, learner,
28-
store_models):
28+
store_models, suppresses_exceptions=True):
2929
predicted, probs, model, failed = None, None, None, False
3030
train_time, test_time = None, None
3131
try:
@@ -45,6 +45,8 @@ def _mp_worker(fold_i, train_data, test_data, learner_i, learner,
4545
test_time = time() - t0
4646
# Different models can fail at any time raising any exception
4747
except Exception as ex: # pylint: disable=broad-except
48+
if not suppresses_exceptions:
49+
raise ex
4850
failed = ex
4951
return _MpResults(fold_i, learner_i, store_models and model,
5052
failed, len(test_data), predicted, probs,
@@ -96,6 +98,7 @@ def __init__(self, data=None, *,
9698
row_indices=None, folds=None, score_by_folds=True,
9799
learners=None, models=None, failed=None,
98100
actual=None, predicted=None, probabilities=None,
101+
# pylint: disable=unused-argument
99102
store_data=None, store_models=None,
100103
train_time=None, test_time=None):
101104
"""
@@ -426,7 +429,8 @@ def fit(self, *args, **kwargs):
426429
DeprecationWarning)
427430
return self(*args, **kwargs)
428431

429-
def __call__(self, data, learners, preprocessor=None, *, callback=None):
432+
def __call__(self, data, learners, preprocessor=None, *, callback=None,
433+
suppresses_exceptions=True):
430434
"""
431435
Args:
432436
data (Orange.data.Table): data to be used (usually split) into
@@ -435,6 +439,7 @@ def __call__(self, data, learners, preprocessor=None, *, callback=None):
435439
preprocessor (Orange.preprocess.Preprocess): preprocessor applied
436440
on training data
437441
callback (Callable): a function called to notify about the progress
442+
suppresses_exceptions (bool): suppress the exceptions if True
438443
439444
Returns:
440445
results (Result): results of testing
@@ -457,7 +462,10 @@ def __call__(self, data, learners, preprocessor=None, *, callback=None):
457462
part_results = []
458463
parts = np.linspace(.0, .99, len(learners) * len(indices) + 1)[1:]
459464
for progress, part in zip(parts, args_iter):
460-
part_results.append(_mp_worker(*(part + ())))
465+
part_results.append(
466+
_mp_worker(*(part + ()),
467+
suppresses_exceptions=suppresses_exceptions)
468+
)
461469
callback(progress)
462470
callback(1)
463471

@@ -723,7 +731,7 @@ def __new__(cls, data=None, test_data=None, learners=None,
723731
test_data=test_data, **kwargs)
724732

725733
def __call__(self, data, test_data, learners, preprocessor=None,
726-
*, callback=None):
734+
*, callback=None, suppresses_exceptions=True):
727735
"""
728736
Args:
729737
data (Orange.data.Table): training data
@@ -732,6 +740,7 @@ def __call__(self, data, test_data, learners, preprocessor=None,
732740
preprocessor (Orange.preprocess.Preprocess): preprocessor applied
733741
on training data
734742
callback (Callable): a function called to notify about the progress
743+
suppresses_exceptions (bool): suppress the exceptions if True
735744
736745
Returns:
737746
results (Result): results of testing
@@ -746,7 +755,7 @@ def __call__(self, data, test_data, learners, preprocessor=None,
746755
for (learner_i, learner) in enumerate(learners):
747756
part_results.append(
748757
_mp_worker(0, train_data, test_data, learner_i, learner,
749-
self.store_models))
758+
self.store_models, suppresses_exceptions))
750759
callback((learner_i + 1) / len(learners))
751760
callback(1)
752761

@@ -778,13 +787,14 @@ def __new__(cls, data=None, learners=None, preprocessor=None, **kwargs):
778787
**kwargs)
779788

780789
def __call__(self, data, learners, preprocessor=None, *, callback=None,
781-
**kwargs):
790+
suppresses_exceptions=True, **kwargs):
782791
kwargs.setdefault("test_data", data)
783792
# if kwargs contains anything besides test_data, this will be detected
784793
# (and complained about) by super().__call__
785794
return super().__call__(
786795
data=data, learners=learners, preprocessor=preprocessor,
787-
callback=callback, **kwargs)
796+
callback=callback, suppresses_exceptions=suppresses_exceptions,
797+
**kwargs)
788798

789799

790800
def sample(table, n=0.7, stratified=False, replace=False,

Orange/modelling/randomforest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from Orange.base import RandomForestModel
1+
from Orange.base import RandomForestModel, Learner
22
from Orange.classification import RandomForestLearner as RFClassification
33
from Orange.data import Variable
44
from Orange.modelling import SklFitter
@@ -24,3 +24,8 @@ class RandomForestLearner(SklFitter, _FeatureScorerMixin):
2424
'regression': RFRegression}
2525

2626
__returns__ = RandomForestModel
27+
28+
@property
29+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
30+
return [self.FittedParameter("n_estimators", "Number of trees",
31+
int, 1, None)]

Orange/regression/pls.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import Tuple
2-
31
import numpy as np
42
import scipy.stats as ss
53
import sklearn.cross_decomposition as skl_pls
64
from sklearn.preprocessing import StandardScaler
75

6+
from Orange.base import Learner
87
from Orange.data import Table, Domain, Variable, \
98
ContinuousVariable, StringVariable
109
from Orange.data.util import get_unique_names, SharedComputeValue
@@ -163,11 +162,11 @@ def coefficients_table(self):
163162
return coef_table
164163

165164
@property
166-
def rotations(self) -> Tuple[np.ndarray, np.ndarray]:
165+
def rotations(self) -> tuple[np.ndarray, np.ndarray]:
167166
return self.skl_model.x_rotations_, self.skl_model.y_rotations_
168167

169168
@property
170-
def loadings(self) -> Tuple[np.ndarray, np.ndarray]:
169+
def loadings(self) -> tuple[np.ndarray, np.ndarray]:
171170
return self.skl_model.x_loadings_, self.skl_model.y_loadings_
172171

173172
def residuals_normal_probability(self, data: Table) -> Table:
@@ -256,6 +255,11 @@ def incompatibility_reason(self, domain):
256255
reason = "Only numeric target variables expected."
257256
return reason
258257

258+
@property
259+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
260+
return [self.FittedParameter("n_components", "Components",
261+
int, 1, None)]
262+
259263

260264
if __name__ == '__main__':
261265
import Orange

Orange/regression/tests/test_pls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ def table(rows, attr, variables):
2121

2222

2323
class TestPLSRegressionLearner(unittest.TestCase):
24+
def test_fitted_parameters(self):
25+
fitted_parameters = PLSRegressionLearner().fitted_parameters
26+
self.assertIsInstance(fitted_parameters, list)
27+
self.assertEqual(len(fitted_parameters), 1)
28+
2429
def test_allow_y_dim(self):
2530
""" The current PLS version allows only a single Y dimension. """
2631
learner = PLSRegressionLearner(n_components=2)
Lines changed: 26 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)