|
1 | 1 | import warnings |
2 | 2 | from functools import partial |
3 | | -from itertools import chain |
4 | 3 |
|
5 | 4 | import numpy as np |
6 | 5 |
|
|
11 | 10 | QSortFilterProxyModel |
12 | 11 | from sklearn.exceptions import UndefinedMetricWarning |
13 | 12 |
|
14 | | -from Orange.data import Variable, DiscreteVariable, ContinuousVariable |
| 13 | +from Orange.data import Table, DiscreteVariable, ContinuousVariable |
15 | 14 | from Orange.evaluation import scoring |
16 | 15 | from Orange.widgets import gui |
17 | 16 | from Orange.widgets.utils.tableview import table_selection_to_mime_data |
@@ -78,14 +77,33 @@ def learner_name(learner): |
78 | 77 | return getattr(learner, "name", type(learner).__name__) |
79 | 78 |
|
80 | 79 |
|
81 | | -def usable_scorers(target: Variable): |
82 | | - order = {name: i |
83 | | - for i, name in enumerate(BUILTIN_SCORERS_ORDER[type(target)])} |
| 80 | +def usable_scorers(data: Table): |
| 81 | + if not data: |
| 82 | + return [] |
| 83 | + |
| 84 | + problem_type = data.attributes.get("problem_type", None) |
| 85 | + target = data.domain.class_var |
| 86 | + |
84 | 87 | # 'abstract' is retrieved from __dict__ to avoid inheriting |
85 | | - usable = (cls for cls in scoring.Score.registry.values() |
86 | | - if cls.is_scalar and not cls.__dict__.get("abstract") |
87 | | - and isinstance(target, cls.class_types)) |
88 | | - return sorted(usable, key=lambda cls: order.get(cls.name, 99)) |
| 88 | + scorer_candidates = [cls for cls in scoring.Score.registry.values() |
| 89 | + if cls.is_scalar and not cls.__dict__.get("abstract")] |
| 90 | + |
| 91 | + # If problem_type is not specified and 'domain.class_var' is set |
| 92 | + # use builtin scorers and don't brake the default behaviour. |
| 93 | + usable = [] |
| 94 | + if problem_type is None and target: |
| 95 | + order = {name: i |
| 96 | + for i, name in enumerate(BUILTIN_SCORERS_ORDER[type(target)])} |
| 97 | + usable = sorted((cls for cls in scorer_candidates |
| 98 | + if isinstance(target, cls.class_types) and cls.is_built_in), |
| 99 | + key=lambda cls: order.get(cls.name, 99)) |
| 100 | + |
| 101 | + elif problem_type and data.domain.class_vars: |
| 102 | + usable = [cls for cls in scoring.Score.registry.values() |
| 103 | + if not cls.is_built_in and cls.problem_type == problem_type and |
| 104 | + all(isinstance(target, cls.class_types) for target in data.domain.class_vars)] |
| 105 | + |
| 106 | + return usable |
89 | 107 |
|
90 | 108 |
|
91 | 109 | def scorer_caller(scorer, ovr_results, target=None): |
@@ -131,9 +149,9 @@ def is_bad(x): |
131 | 149 |
|
132 | 150 |
|
133 | 151 | class ScoreTable(OWComponent, QObject): |
134 | | - shown_scores = \ |
135 | | - Setting(set(chain(*BUILTIN_SCORERS_ORDER.values()))) |
136 | | - |
| 152 | + shown_scores = Setting(set(scorer.name for scorer in |
| 153 | + scoring.Score.registry.values() if |
| 154 | + scorer.shown_by_default)) |
137 | 155 | shownScoresChanged = Signal() |
138 | 156 |
|
139 | 157 | class ItemDelegate(QStyledItemDelegate): |
|
0 commit comments