Skip to content

Commit fdcee4b

Browse files
committed
scoring: find usable scorers for non-built-in problem types
1 parent ea03b36 commit fdcee4b

File tree

4 files changed

+45
-17
lines changed

4 files changed

+45
-17
lines changed

Orange/evaluation/scoring.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ class Score(metaclass=ScoreMetaType):
6666
name = None
6767
long_name = None #: A short user-readable name (e.g. a few words)
6868

69+
# Marks whether this Scorer is one of the built-in Scorers:
70+
is_built_in = True
71+
# If True, the Scorer is shown in the Scorer table by default.
72+
shown_by_default = True
73+
# Placeholder for a problem type of non-built-in Scorers
74+
problem_type = "built-in"
75+
6976
def __new__(cls, results=None, **kwargs):
7077
self = super().__new__(cls)
7178
if results is not None:
@@ -270,6 +277,7 @@ class LogLoss(ClassificationScore):
270277
271278
"""
272279
__wraps__ = skl_metrics.log_loss
280+
shown_by_default = False
273281

274282
def compute_score(self, results, eps=1e-15, normalize=True,
275283
sample_weight=None):
@@ -285,6 +293,7 @@ def compute_score(self, results, eps=1e-15, normalize=True,
285293

286294
class Specificity(ClassificationScore):
287295
is_binary = True
296+
shown_by_default = False
288297

289298
@staticmethod
290299
def calculate_weights(results):
@@ -360,6 +369,7 @@ class R2(RegressionScore):
360369

361370
class CVRMSE(RegressionScore):
362371
long_name = "Coefficient of variation of the RMSE"
372+
shown_by_default = False
363373

364374
def compute_score(self, results):
365375
mean = np.nanmean(results.actual)

Orange/widgets/evaluate/owpredictions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def _call_predictors(self):
283283
def _update_scores(self):
284284
model = self.score_table.model
285285
model.clear()
286-
scorers = usable_scorers(self.class_var) if self.class_var else []
286+
scorers = usable_scorers(self.data)
287287
self.score_table.update_header(scorers)
288288
errors = []
289289
for pred in self.predictors:

Orange/widgets/evaluate/owtestandscore.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -511,10 +511,10 @@ def _which_missing_data(self):
511511
# - we don't gain much with it
512512
# - it complicates the unit tests
513513
def _update_scorers(self):
514-
if self.data and self.data.domain.class_var:
515-
new_scorers = usable_scorers(self.data.domain.class_var)
516-
else:
517-
new_scorers = []
514+
new_scorers = []
515+
if self.data:
516+
new_scorers = usable_scorers(self.data)
517+
518518
# Don't unnecessarily reset the combo because this would always reset
519519
# comparison_criterion; we also set it explicitly, though, for clarity
520520
if new_scorers != self.scorers:

Orange/widgets/evaluate/utils.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22
from functools import partial
3-
from itertools import chain
43

54
import numpy as np
65

@@ -11,7 +10,7 @@
1110
QSortFilterProxyModel
1211
from sklearn.exceptions import UndefinedMetricWarning
1312

14-
from Orange.data import Variable, DiscreteVariable, ContinuousVariable
13+
from Orange.data import Table, DiscreteVariable, ContinuousVariable
1514
from Orange.evaluation import scoring
1615
from Orange.widgets import gui
1716
from Orange.widgets.utils.tableview import table_selection_to_mime_data
@@ -78,14 +77,33 @@ def learner_name(learner):
7877
return getattr(learner, "name", type(learner).__name__)
7978

8079

81-
def usable_scorers(target: Variable):
82-
order = {name: i
83-
for i, name in enumerate(BUILTIN_SCORERS_ORDER[type(target)])}
80+
def usable_scorers(data: Table):
81+
if not data:
82+
return []
83+
84+
problem_type = data.attributes.get("problem_type", None)
85+
target = data.domain.class_var
86+
8487
# 'abstract' is retrieved from __dict__ to avoid inheriting
85-
usable = (cls for cls in scoring.Score.registry.values()
86-
if cls.is_scalar and not cls.__dict__.get("abstract")
87-
and isinstance(target, cls.class_types))
88-
return sorted(usable, key=lambda cls: order.get(cls.name, 99))
88+
scorer_candidates = [cls for cls in scoring.Score.registry.values()
89+
if cls.is_scalar and not cls.__dict__.get("abstract")]
90+
91+
# If problem_type is not specified and 'domain.class_var' is set
92+
# use builtin scorers and don't break the default behaviour.
93+
usable = []
94+
if problem_type is None and target:
95+
order = {name: i
96+
for i, name in enumerate(BUILTIN_SCORERS_ORDER[type(target)])}
97+
usable = sorted((cls for cls in scorer_candidates
98+
if isinstance(target, cls.class_types) and cls.is_built_in),
99+
key=lambda cls: order.get(cls.name, 99))
100+
101+
elif problem_type and data.domain.class_vars:
102+
usable = [cls for cls in scoring.Score.registry.values()
103+
if not cls.is_built_in and cls.problem_type == problem_type and
104+
all(isinstance(target, cls.class_types) for target in data.domain.class_vars)]
105+
106+
return usable
89107

90108

91109
def scorer_caller(scorer, ovr_results, target=None):
@@ -131,9 +149,9 @@ def is_bad(x):
131149

132150

133151
class ScoreTable(OWComponent, QObject):
134-
shown_scores = \
135-
Setting(set(chain(*BUILTIN_SCORERS_ORDER.values())))
136-
152+
shown_scores = Setting(set(scorer.name for scorer in
153+
scoring.Score.registry.values() if
154+
scorer.shown_by_default))
137155
shownScoresChanged = Signal()
138156

139157
class ItemDelegate(QStyledItemDelegate):

0 commit comments

Comments
 (0)