Skip to content

Commit 064c928

Browse files
committed
Calibration plot: Add context settings
1 parent dea15bf commit 064c928

File tree

2 files changed

+32
-43
lines changed

2 files changed

+32
-43
lines changed
Lines changed: 23 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,30 @@
1+
from Orange.data import Variable
12
from Orange.widgets import settings
2-
from Orange.widgets.utils import getdeepattr
33

44

55
class EvaluationResultsContextHandler(settings.ContextHandler):
6-
def __init__(self, targetAttr, selectedAttr):
7-
super().__init__()
8-
self.targetAttr, self.selectedAttr = targetAttr, selectedAttr
6+
"""Context handler for evaluation results"""
97

10-
#noinspection PyMethodOverriding
11-
def match(self, context, cnames, cvalues):
12-
return (cnames, cvalues) == (
13-
context.classifierNames, context.classValues) and 2
8+
def open_context(self, widget, classes, classifier_names):
9+
if isinstance(classes, Variable):
10+
if classes.is_discrete:
11+
classes = classes.values
12+
else:
13+
classes = None
14+
super().open_context(widget, classes, classifier_names)
1415

15-
def fast_save(self, widget, name, value):
16-
context = widget.current_context
17-
if name == self.targetAttr:
18-
context.targetClass = value
19-
elif name == self.selectedAttr:
20-
context.selectedClassifiers = list(value)
16+
def new_context(self, classes, classifier_names):
17+
context = super().new_context()
18+
context.classes = classes
19+
context.classifier_names = classifier_names
20+
return context
2121

22-
def settings_from_widget(self, widget, *args):
23-
super().settings_from_widget(widget, *args)
24-
context = widget.current_context
25-
context.targetClass = getdeepattr(widget, self.targetAttr)
26-
context.selectedClassifiers = list(getdeepattr(self.selectedAttr))
27-
28-
def settings_to_widget(self, widget, *args):
29-
super().settings_to_widget(widget, *args)
30-
context = widget.current_context
31-
if context.targetClass is not None:
32-
setattr(widget, self.targetAttr, context.targetClass)
33-
if context.selectedClassifiers is not None:
34-
setattr(widget, self.selectedAttr, context.selectedClassifiers)
35-
36-
#noinspection PyMethodOverriding
37-
def find_or_create_context(self, widget, results):
38-
cnames = [c.name for c in results.classifiers]
39-
cvalues = results.classValues
40-
context, isNew = super().find_or_create_context(
41-
widget, results.classifierNames, results.classValues)
42-
if isNew:
43-
context.classifierNames = results.classifierNames
44-
context.classValues = results.classValues
45-
context.selectedClassifiers = None
46-
context.targetClass = None
47-
return context, isNew
22+
def match(self, context, classes, classifier_names):
23+
if classifier_names != context.classifier_names:
24+
return self.NO_MATCH
25+
elif isinstance(classes, Variable) and classes.is_continuous:
26+
return (self.PERFECT_MATCH if context.classes is None
27+
else self.NO_MATCH)
28+
else:
29+
return (self.PERFECT_MATCH if context.classes == classes
30+
else self.NO_MATCH)

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from Orange.evaluation import Results
1313
from Orange.evaluation.performance_curves import Curves
1414
from Orange.widgets import widget, gui, settings
15+
from Orange.widgets.evaluate.contexthandlers import \
16+
EvaluationResultsContextHandler
1517
from Orange.widgets.evaluate.utils import \
1618
check_results_adequacy, results_for_preview
1719
from Orange.widgets.utils import colorpalette, colorbrewer
@@ -84,9 +86,9 @@ class Information(widget.OWWidget.Information):
8486
no_out + "select a single model - the widget can output only one")
8587
non_binary_class = Msg(no_out + "cannot calibrate non-binary classes")
8688

87-
88-
target_index = settings.Setting(0)
89-
selected_classifiers = settings.Setting([])
89+
settingsHandler = EvaluationResultsContextHandler()
90+
target_index = settings.ContextSetting(0)
91+
selected_classifiers = settings.ContextSetting([])
9092
score = settings.Setting(0)
9193
output_calibration = settings.Setting(0)
9294
fold_curves = settings.Setting(False)
@@ -168,6 +170,7 @@ def __init__(self):
168170

169171
@Inputs.evaluation_results
170172
def set_results(self, results):
173+
self.closeContext()
171174
self.clear()
172175
results = check_results_adequacy(results, self.Error, check_nan=False)
173176
if results is not None and not results.actual.size:
@@ -177,6 +180,9 @@ def set_results(self, results):
177180
self.results = results
178181
if self.results is not None:
179182
self._initialize(results)
183+
class_var = self.results.domain.class_var
184+
self.target_index = int(len(class_var.values) == 2)
185+
self.openContext(class_var, self.classifier_names)
180186
self._replot()
181187
self.apply()
182188

0 commit comments

Comments
 (0)