Skip to content

Commit 7e72e15

Browse files
committed (author and date omitted in this capture)
Parameter Fitter: Type annotations
1 parent a8f1806 commit 7e72e15

File tree

6 files changed

+37
-38
lines changed

6 files changed

+37
-38
lines changed

Orange/base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import inspect
22
import itertools
3-
from collections import namedtuple
43
from collections.abc import Iterable
54
import re
65
import warnings
7-
from typing import Callable, Dict, Optional, List
6+
from types import NoneType
7+
from typing import Callable, Optional, NamedTuple, Union, Type
88

99
import numpy as np
1010
import scipy
@@ -88,9 +88,15 @@ class Learner(ReprableWithPreprocessors):
8888
#: A sequence of data preprocessors to apply on data prior to
8989
#: fitting the model
9090
preprocessors = ()
91-
FittedParameter = namedtuple(
92-
"FittedParameter",
93-
["parameter_name", "label", "tick_label", "type", "min", "max"]
91+
FittedParameter = NamedTuple(
92+
"FittedParameter", [
93+
("name", str),
94+
("label", str),
95+
("tick_label", str),
96+
("type", Type),
97+
("min", Union[int, NoneType]),
98+
("max", Union[int, NoneType]),
99+
]
94100
)
95101

96102
# Note: Do not use this class attribute.
@@ -184,7 +190,7 @@ def active_preprocessors(self):
184190
self.preprocessors is not type(self).preprocessors):
185191
yield from type(self).preprocessors
186192

187-
def fitted_parameters(self, *args, **kwargs) -> List:
193+
def fitted_parameters(self, *args, **kwargs) -> list:
188194
return []
189195

190196
# pylint: disable=no-self-use
@@ -891,5 +897,5 @@ def __init__(self, preprocessors=None, **kwargs):
891897
self.params = kwargs
892898

893899
@SklLearner.params.setter
894-
def params(self, values: Dict):
900+
def params(self, values: dict):
895901
self._params = values

Orange/classification/random_forest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import List
2-
31
import sklearn.ensemble as skl_ensemble
42

53
from Orange.base import RandomForestModel, Learner
@@ -61,6 +59,6 @@ def __init__(self,
6159
super().__init__(preprocessors=preprocessors)
6260
self.params = vars()
6361

64-
def fitted_parameters(self) -> List[Learner.FittedParameter]:
62+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
6563
return [self.FittedParameter("n_estimators", "Number of trees",
6664
"Trees", int, 1, None)]

Orange/modelling/randomforest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Union
1+
from typing import Union
22

33
from Orange.base import RandomForestModel, Learner
44
from Orange.classification import RandomForestLearner as RFClassification
@@ -30,5 +30,5 @@ class RandomForestLearner(SklFitter, _FeatureScorerMixin):
3030
def fitted_parameters(
3131
self,
3232
problem_type: Union[str, Table, Domain]
33-
) -> List[Learner.FittedParameter]:
33+
) -> list[Learner.FittedParameter]:
3434
return self.get_learner(problem_type).fitted_parameters()

Orange/regression/pls.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Tuple, List
2-
31
import numpy as np
42
import scipy.stats as ss
53
import sklearn.cross_decomposition as skl_pls
@@ -164,11 +162,11 @@ def coefficients_table(self):
164162
return coef_table
165163

166164
@property
167-
def rotations(self) -> Tuple[np.ndarray, np.ndarray]:
165+
def rotations(self) -> tuple[np.ndarray, np.ndarray]:
168166
return self.skl_model.x_rotations_, self.skl_model.y_rotations_
169167

170168
@property
171-
def loadings(self) -> Tuple[np.ndarray, np.ndarray]:
169+
def loadings(self) -> tuple[np.ndarray, np.ndarray]:
172170
return self.skl_model.x_loadings_, self.skl_model.y_loadings_
173171

174172
def residuals_normal_probability(self, data: Table) -> Table:
@@ -257,7 +255,7 @@ def incompatibility_reason(self, domain):
257255
reason = "Only numeric target variables expected."
258256
return reason
259257

260-
def fitted_parameters(self) -> List[Learner.FittedParameter]:
258+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
261259
return [self.FittedParameter("n_components", "Number of components",
262260
"Comp", int, 1, None)]
263261

Orange/regression/random_forest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import List
2-
31
import sklearn.ensemble as skl_ensemble
42

53
from Orange.base import RandomForestModel, Learner
@@ -60,6 +58,6 @@ def __init__(self,
6058
super().__init__(preprocessors=preprocessors)
6159
self.params = vars()
6260

63-
def fitted_parameters(self) -> List[Learner.FittedParameter]:
61+
def fitted_parameters(self) -> list[Learner.FittedParameter]:
6462
return [self.FittedParameter("n_estimators", "Number of trees",
6563
"Trees", int, 1, None)]

Orange/widgets/evaluate/owparameterfitter.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple, Callable, List, Dict, Iterable, Sized
1+
from typing import Optional, Callable, Iterable, Sized
22

33
import numpy as np
44
from AnyQt.QtCore import QPointF, Qt
@@ -13,7 +13,6 @@
1313

1414
from Orange.base import Learner
1515
from Orange.data import Table
16-
from Orange.data.table import DomainTransformationError
1716
from Orange.evaluation import CrossValidation, TestOnTrainingData, Results
1817
from Orange.evaluation.scoring import Score, AUC, R2
1918
from Orange.modelling import Fitter
@@ -32,16 +31,16 @@
3231

3332
N_FOLD = 7
3433
MIN_MAX_SPIN = 100000
35-
ScoreType = Tuple[int, Tuple[float, float]]
34+
ScoreType = tuple[int, tuple[float, float]]
3635
# scores, score name, tick label
37-
FitterResults = Tuple[List[ScoreType], str, str]
36+
FitterResults = tuple[list[ScoreType], str, str]
3837

3938

4039
def _validate(
4140
data: Table,
4241
learner: Learner,
4342
scorer: type[Score]
44-
) -> Tuple[float, float]:
43+
) -> tuple[float, float]:
4544
# dummy call - Validation would silence the exceptions
4645
learner(data)
4746

@@ -55,18 +54,18 @@ def _search(
5554
data: Table,
5655
learner: Learner,
5756
fitted_parameter_props: Learner.FittedParameter,
58-
initial_parameters: Dict,
57+
initial_parameters: dict,
5958
steps: Sized,
6059
progress_callback: Callable = dummy_callback
6160
) -> FitterResults:
6261
progress_callback(0, "Calculating...")
6362
scores = []
6463
scorer = AUC if data.domain.has_discrete_class else R2
65-
parameter_name = fitted_parameter_props.parameter_name
64+
name = fitted_parameter_props.name
6665
for i, value in enumerate(steps):
6766
progress_callback(i / len(steps))
6867
params = initial_parameters.copy()
69-
params[parameter_name] = value
68+
params[name] = value
7069
result = _validate(data, type(learner)(**params), scorer)
7170
scores.append((value, result))
7271
return scores, scorer.name, fitted_parameter_props.tick_label
@@ -76,7 +75,7 @@ def run(
7675
data: Table,
7776
learner: Learner,
7877
fitted_parameter_props: Learner.FittedParameter,
79-
initial_parameters: Dict,
78+
initial_parameters: dict,
8079
steps: Sized,
8180
state: TaskState
8281
) -> FitterResults:
@@ -97,7 +96,7 @@ class ParameterSetter(CommonParameterSetter):
9796
DEFAULT_ALPHA_GRID, DEFAULT_SHOW_GRID = 80, True
9897

9998
def __init__(self, master):
100-
self.grid_settings: Dict = None
99+
self.grid_settings: dict = None
101100
self.master: FitterPlot = master
102101
super().__init__()
103102

@@ -148,7 +147,7 @@ def __init__(self):
148147
super().__init__(enableMenu=False)
149148
self.__bar_item_tr: pg.BarGraphItem = None
150149
self.__bar_item_cv: pg.BarGraphItem = None
151-
self.__data: List[ScoreType] = None
150+
self.__data: list[ScoreType] = None
152151
self.legend = self._create_legend()
153152
self.parameter_setter = ParameterSetter(self)
154153
self.setMouseEnabled(False, False)
@@ -177,7 +176,7 @@ def clear_all(self):
177176

178177
def set_data(
179178
self,
180-
scores: List[ScoreType],
179+
scores: list[ScoreType],
181180
score_name: str,
182181
tick_name: str
183182
):
@@ -243,8 +242,8 @@ def __get_index_at(self, point: QPointF) -> Optional[int]:
243242
x = point.x()
244243
index = round(x)
245244
# pylint: disable=unsubscriptable-object
246-
heights_tr: List = self.__bar_item_tr.opts["height"]
247-
heights_cv: List = self.__bar_item_cv.opts["height"]
245+
heights_tr: list = self.__bar_item_tr.opts["height"]
246+
heights_cv: list = self.__bar_item_cv.opts["height"]
248247
if 0 <= index < len(heights_tr) and abs(index - x) <= self.BAR_WIDTH:
249248
if index > x and 0 <= point.y() <= heights_tr[index]:
250249
return index
@@ -356,15 +355,15 @@ def __on_setting_changed(self):
356355
self.commit.deferred()
357356

358357
@property
359-
def fitted_parameters(self) -> List:
358+
def fitted_parameters(self) -> list:
360359
if not self._learner or not self._data:
361360
return []
362361
return self._learner.fitted_parameters(self._data) \
363362
if isinstance(self._learner, Fitter) \
364363
else self._learner.fitted_parameters()
365364

366365
@property
367-
def initial_parameters(self) -> Dict:
366+
def initial_parameters(self) -> dict:
368367
if not self._learner or not self._data:
369368
return {}
370369
return self._learner.get_params(self._data) \
@@ -444,15 +443,15 @@ def _set_range_controls(self):
444443
else:
445444
self.__spin_min.setMinimum(-MIN_MAX_SPIN)
446445
self.__spin_max.setMinimum(-MIN_MAX_SPIN)
447-
self.minimum = self.initial_parameters[param.parameter_name]
446+
self.minimum = self.initial_parameters[param.name]
448447
if param.max is not None:
449448
self.__spin_min.setMaximum(param.max)
450449
self.__spin_max.setMaximum(param.max)
451450
self.maximum = param.max
452451
else:
453452
self.__spin_min.setMaximum(MIN_MAX_SPIN)
454453
self.__spin_max.setMaximum(MIN_MAX_SPIN)
455-
self.maximum = self.initial_parameters[param.parameter_name]
454+
self.maximum = self.initial_parameters[param.name]
456455

457456
def _update_preview(self):
458457
self.preview = str(list(self.steps))

0 commit comments

Comments (0)