Skip to content

Commit fb41ee3

Browse files
committed
Fixing binary scores in Test and Score widget to average in case of no target selected
1 parent 5043829 commit fb41ee3

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

Orange/widgets/evaluate/owtestlearners.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def _update_stats_model(self):
567567

568568
# Cell variable is used immediatelly, it's not stored
569569
# pylint: disable=cell-var-from-loop
570-
stats = [Try(scorer_caller(scorer, ovr_results))
570+
stats = [Try(scorer_caller(scorer, ovr_results, target=1))
571571
for scorer in self.scorers]
572572
else:
573573
stats = None
@@ -950,9 +950,9 @@ def onDeleteWidget(self):
950950
super().onDeleteWidget()
951951

952952

953-
def scorer_caller(scorer, ovr_results):
953+
def scorer_caller(scorer, ovr_results, target=None):
954954
if scorer.is_binary:
955-
return lambda: scorer(ovr_results, target=1, average='weighted')
955+
return lambda: scorer(ovr_results, target=target, average='weighted')
956956
else:
957957
return lambda: scorer(ovr_results)
958958

Orange/widgets/evaluate/tests/test_owtestlearners.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from AnyQt.QtWidgets import QMenu
88
from AnyQt.QtCore import QPoint
99

10-
from Orange.classification import MajorityLearner
10+
from Orange.classification import MajorityLearner, LogisticRegressionLearner
1111
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
1212
from Orange.evaluation import Results, TestOnTestData
1313
from Orange.evaluation.scoring import ClassificationScore, RegressionScore, \
@@ -247,6 +247,35 @@ class NewRegressionScore(RegressionScore):
247247
del Score.registry["NewClassificationScore"]
248248
del Score.registry["NewRegressionScore"]
249249

250+
def test_target_changing(self):
251+
data = Table("iris")
252+
w = self.widget #: OWTestLearners
253+
254+
w.n_folds = 2
255+
self.send_signal(self.widget.Inputs.train_data, data)
256+
self.send_signal(self.widget.Inputs.learner,
257+
LogisticRegressionLearner(), 0, wait=5000)
258+
259+
average_auc = float(w.view.model().item(0, 1).text())
260+
261+
w.class_selection = "Iris-setosa"
262+
w._on_target_class_changed()
263+
setosa_auc = float(w.view.model().item(0, 1).text())
264+
265+
w.class_selection = "Iris-versicolor"
266+
w._on_target_class_changed()
267+
versicolor_auc = float(w.view.model().item(0, 1).text())
268+
269+
w.class_selection = "Iris-virginica"
270+
w._on_target_class_changed()
271+
virginica_auc = float(w.view.model().item(0, 1).text())
272+
273+
self.assertGreater(average_auc, versicolor_auc)
274+
self.assertGreater(average_auc, virginica_auc)
275+
self.assertLess(average_auc, setosa_auc)
276+
self.assertGreater(setosa_auc, versicolor_auc)
277+
self.assertGreater(setosa_auc, virginica_auc)
278+
250279

251280
class TestHelpers(unittest.TestCase):
252281
def test_results_one_vs_rest(self):

0 commit comments

Comments
 (0)