Skip to content

Commit a5d9023

Browse files
committed
Parameter Fitter: Manual steps
1 parent 450b317 commit a5d9023

File tree

2 files changed

+88
-16
lines changed

2 files changed

+88
-16
lines changed

Orange/widgets/evaluate/owparameterfitter.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Optional, Tuple, Callable, List, Dict
1+
from typing import Optional, Tuple, Callable, List, Dict, Iterable
22

33
import numpy as np
4-
from AnyQt.QtCore import QPointF
4+
from AnyQt.QtCore import QPointF, Qt
55
from AnyQt.QtGui import QStandardItemModel, QStandardItem
66
from AnyQt.QtWidgets import QGraphicsSceneHelpEvent, QToolTip
77

@@ -55,17 +55,11 @@ def _search(
5555
learner: Learner,
5656
fitted_parameter_props: Learner.FittedParameter,
5757
initial_parameters: Dict,
58-
minimum: int,
59-
maximum: int,
58+
steps: Iterable,
6059
progress_callback: Callable = dummy_callback
6160
) -> FitterResults:
6261
progress_callback(0, "Calculating...")
6362
scores = []
64-
step = 1
65-
if maximum - minimum > 0:
66-
exp = int(np.ceil(np.log10(maximum - minimum + 1))) - 1
67-
step = int(10 ** exp)
68-
steps = range(minimum, maximum + step, step)
6963
scorer = AUC if data.domain.has_discrete_class else R2
7064
parameter_name = fitted_parameter_props.parameter_name
7165
for i, value in enumerate(steps):
@@ -82,8 +76,7 @@ def run(
8276
learner: Learner,
8377
fitted_parameter_props: Learner.FittedParameter,
8478
initial_parameters: Dict,
85-
minimum: int,
86-
maximum: int,
79+
steps: Iterable,
8780
state: TaskState
8881
) -> FitterResults:
8982
def callback(i: float, status: str = ""):
@@ -95,7 +88,7 @@ def callback(i: float, status: str = ""):
9588
raise Exception
9689

9790
return _search(data, learner, fitted_parameter_props, initial_parameters,
98-
minimum, maximum, callback)
91+
steps, callback)
9992

10093

10194
class ParameterSetter(CommonParameterSetter):
@@ -280,8 +273,11 @@ class Inputs:
280273
DEFAULT_MINIMUM = 1
281274
DEFAULT_MAXIMUM = 9
282275
parameter_index = Setting(DEFAULT_PARAMETER_INDEX, schema_only=True)
276+
FROM_RANGE, MANUAL = range(2)
277+
type = Setting(FROM_RANGE)
283278
minimum = Setting(DEFAULT_MINIMUM, schema_only=True)
284279
maximum = Setting(DEFAULT_MAXIMUM, schema_only=True)
280+
manual_steps = Setting("", schema_only=True)
285281
auto_commit = Setting(True)
286282

287283
class Error(OWWidget.Error):
@@ -327,17 +323,36 @@ def _add_controls(self):
327323
self.__combo = gui.comboBox(box, self, "parameter_index",
328324
model=self.__parameters_model,
329325
callback=self.__on_parameter_changed)
330-
hbox = gui.hBox(box)
326+
327+
buttons = gui.radioButtons(box, self, "type",
328+
callback=self.__on_setting_changed)
329+
330+
gui.appendRadioButton(buttons, "Range")
331+
hbox = gui.indentedBox(buttons, 20, orientation=Qt.Horizontal)
331332
kw = {"minv": -MIN_MAX_SPIN, "maxv": MIN_MAX_SPIN,
332-
"callback": self.commit.deferred}
333+
"callback": self.__on_setting_changed}
333334
self.__spin_min = gui.spin(hbox, self, "minimum", label="Min:", **kw)
334335
self.__spin_max = gui.spin(hbox, self, "maximum", label="Max:", **kw)
336+
337+
gui.appendRadioButton(buttons, "Manual")
338+
hbox = gui.indentedBox(box, 20, orientation=Qt.Horizontal)
339+
gui.lineEdit(hbox, self, "manual_steps", placeholderText="10, 20, 30",
340+
callback=self.__on_setting_changed)
341+
342+
box = gui.vBox(self.controlArea, "Steps preview")
343+
self.preview = ""
344+
gui.label(box, self, "%(preview)s", wordWrap=True)
345+
335346
gui.rubber(self.controlArea)
336347

337348
gui.auto_apply(self.buttonsArea, self, "auto_commit")
338349

339350
def __on_parameter_changed(self):
340351
self._set_range_controls()
352+
self.__on_setting_changed()
353+
354+
def __on_setting_changed(self):
355+
self._update_preview()
341356
self.commit.deferred()
342357

343358
@property
@@ -355,6 +370,21 @@ def initial_parameters(self) -> Dict:
355370
return self._learner.get_params(self._data) \
356371
if isinstance(self._learner, Fitter) else self._learner.params
357372

373+
@property
374+
def steps(self) -> Iterable[int]:
375+
if self.type == self.FROM_RANGE:
376+
step = 1
377+
diff = self.maximum - self.minimum
378+
if diff > 0:
379+
exp = int(np.ceil(np.log10(diff + 1))) - 1
380+
step = int(10 ** exp)
381+
return range(self.minimum, self.maximum + step, step)
382+
else:
383+
try:
384+
return [int(s) for s in self.manual_steps.split(",")]
385+
except ValueError:
386+
return []
387+
358388
@Inputs.data
359389
@check_multiple_targets_input
360390
def set_data(self, data: Table):
@@ -400,6 +430,7 @@ def handleNewSignals(self):
400430
self.maximum = self.__pending_maximum
401431
self.__pending_maximum = None
402432

433+
self._update_preview()
403434
self.commit.now()
404435

405436
def _set_range_controls(self):
@@ -424,6 +455,9 @@ def _set_range_controls(self):
424455
self.__spin_max.setMaximum(MIN_MAX_SPIN)
425456
self.maximum = self.initial_parameters[param.parameter_name]
426457

458+
def _update_preview(self):
459+
self.preview = str(list(self.steps))
460+
427461
def clear(self):
428462
self.cancel()
429463
self.graph.clear_all()
@@ -437,8 +471,7 @@ def commit(self):
437471
self.graph.clear_all()
438472
self.start(run, self._data, self._learner,
439473
self.fitted_parameters[self.parameter_index],
440-
self.initial_parameters,
441-
self.minimum, self.maximum)
474+
self.initial_parameters, self.steps)
442475

443476
def on_done(self, result: FitterResults):
444477
self.graph.set_data(*result)

Orange/widgets/evaluate/tests/test_owparameterfitter.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,45 @@ def test_random_forest(self):
6767
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
6868
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())
6969

70+
def test_plot(self):
71+
self.send_signal(self.widget.Inputs.data, self._housing)
72+
self.send_signal(self.widget.Inputs.learner, self._pls)
73+
self.wait_until_finished()
74+
75+
x = self.widget.graph._FitterPlot__bar_item_tr.opts["x"]
76+
self.assertEqual(list(x), [-0.2, 0.8])
77+
x = self.widget.graph._FitterPlot__bar_item_cv.opts["x"]
78+
self.assertEqual(list(x), [0.2, 1.2])
79+
80+
def test_manual_steps(self):
81+
self.send_signal(self.widget.Inputs.data, self._housing)
82+
self.send_signal(self.widget.Inputs.learner, self._pls)
83+
self.wait_until_finished()
84+
85+
self.widget.controls.manual_steps.setText("1, 2, 3")
86+
self.widget.controls.type.buttons[1].click()
87+
self.wait_until_finished()
88+
89+
x = self.widget.graph._FitterPlot__bar_item_tr.opts["x"]
90+
self.assertEqual(list(x), [-0.2, 0.8, 1.8])
91+
x = self.widget.graph._FitterPlot__bar_item_cv.opts["x"]
92+
self.assertEqual(list(x), [0.2, 1.2, 2.2])
93+
94+
def test_steps_preview(self):
95+
self.send_signal(self.widget.Inputs.data, self._housing)
96+
self.send_signal(self.widget.Inputs.learner, self._pls)
97+
self.wait_until_finished()
98+
self.assertEqual(self.widget.preview, "[1, 2]")
99+
100+
self.widget.controls.type.buttons[1].click()
101+
self.wait_until_finished()
102+
self.assertEqual(self.widget.preview, "[]")
103+
104+
self.widget.controls.manual_steps.setText("10, 15, 20, 25")
105+
self.widget.controls.type.buttons[1].click()
106+
self.wait_until_finished()
107+
self.assertEqual(self.widget.preview, "[10, 15, 20, 25]")
108+
70109

71110
if __name__ == "__main__":
72111
unittest.main()

0 commit comments

Comments
 (0)