Skip to content

Commit 9deed46

Browse files
authored
Merge pull request #3881 from janezd/calibration-plot-curves
[ENH] Calibration plot (add performance curves) and a new Calibrated Learner widget
2 parents 15d4fe6 + 864d7b5 commit 9deed46

24 files changed

+2036
-204
lines changed

Orange/classification/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from .rules import *
2020
from .sgd import *
2121
from .neural_network import *
22+
from .calibration import *
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import numpy as np
2+
from sklearn.isotonic import IsotonicRegression
3+
from sklearn.calibration import _SigmoidCalibration
4+
5+
from Orange.classification import Model, Learner
6+
from Orange.evaluation import TestOnTrainingData
7+
from Orange.evaluation.performance_curves import Curves
8+
9+
__all__ = ["ThresholdClassifier", "ThresholdLearner",
10+
"CalibratedLearner", "CalibratedClassifier"]
11+
12+
13+
class ThresholdClassifier(Model):
14+
"""
15+
A model that wraps a binary model and sets a different threshold.
16+
17+
The target class is the class with index 1. A data instances is classified
18+
to class 1 it the probability of this class equals or exceeds the threshold
19+
20+
Attributes:
21+
base_model (Orange.classification.Model): base mode
22+
threshold (float): decision threshold
23+
"""
24+
def __init__(self, base_model, threshold):
25+
if not base_model.domain.class_var.is_discrete \
26+
or len(base_model.domain.class_var.values) != 2:
27+
raise ValueError("ThresholdClassifier requires a binary class")
28+
29+
super().__init__(base_model.domain, base_model.original_domain)
30+
self.name = f"{base_model.name}, thresh={threshold:.2f}"
31+
self.base_model = base_model
32+
self.threshold = threshold
33+
34+
def __call__(self, data, ret=Model.Value):
35+
probs = self.base_model(data, ret=Model.Probs)
36+
if ret == Model.Probs:
37+
return probs
38+
class_probs = probs[:, 1].ravel()
39+
with np.errstate(invalid="ignore"): # we fix nanx below
40+
vals = (class_probs >= self.threshold).astype(float)
41+
vals[np.isnan(class_probs)] = np.nan
42+
if ret == Model.Value:
43+
return vals
44+
else:
45+
return vals, probs
46+
47+
48+
class ThresholdLearner(Learner):
49+
"""
50+
A learner that runs another learner and then finds the optimal threshold
51+
for CA or F1 on the training data.
52+
53+
Attributes:
54+
base_leaner (Learner): base learner
55+
threshold_criterion (int):
56+
`ThresholdLearner.OptimizeCA` or `ThresholdLearner.OptimizeF1`
57+
"""
58+
__returns__ = ThresholdClassifier
59+
60+
OptimizeCA, OptimizeF1 = range(2)
61+
62+
def __init__(self, base_learner, threshold_criterion=OptimizeCA):
63+
super().__init__()
64+
self.base_learner = base_learner
65+
self.threshold_criterion = threshold_criterion
66+
67+
def fit_storage(self, data):
68+
"""
69+
Induce a model using the provided `base_learner`, compute probabilities
70+
on training data and the find the optimal decision thresholds. In case
71+
of ties, select the threshold that is closest to 0.5.
72+
"""
73+
if not data.domain.class_var.is_discrete \
74+
or len(data.domain.class_var.values) != 2:
75+
raise ValueError("ThresholdLearner requires a binary class")
76+
77+
res = TestOnTrainingData(data, [self.base_learner], store_models=True)
78+
model = res.models[0, 0]
79+
curves = Curves.from_results(res)
80+
curve = [curves.ca, curves.f1][self.threshold_criterion]()
81+
# In case of ties, we want the optimal threshold that is closest to 0.5
82+
best_threshs = curves.probs[curve == np.max(curve)]
83+
threshold = best_threshs[min(np.searchsorted(best_threshs, 0.5),
84+
len(best_threshs) - 1)]
85+
return ThresholdClassifier(model, threshold)
86+
87+
88+
class CalibratedClassifier(Model):
89+
"""
90+
A model that wraps another model and recalibrates probabilities
91+
92+
Attributes:
93+
base_model (Mode): base mode
94+
calibrators (list of callable):
95+
list of functions that get a vector of probabilities and return
96+
calibrated probabilities
97+
"""
98+
def __init__(self, base_model, calibrators):
99+
if not base_model.domain.class_var.is_discrete:
100+
raise ValueError("CalibratedClassifier requires a discrete target")
101+
102+
super().__init__(base_model.domain, base_model.original_domain)
103+
self.base_model = base_model
104+
self.calibrators = calibrators
105+
self.name = f"{base_model.name}, calibrated"
106+
107+
def __call__(self, data, ret=Model.Value):
108+
probs = self.base_model(data, Model.Probs)
109+
cal_probs = self.calibrated_probs(probs)
110+
if ret == Model.Probs:
111+
return cal_probs
112+
vals = np.argmax(cal_probs, axis=1)
113+
if ret == Model.Value:
114+
return vals
115+
else:
116+
return vals, cal_probs
117+
118+
def calibrated_probs(self, probs):
119+
if self.calibrators:
120+
ps = np.hstack(
121+
tuple(
122+
calibr.predict(cls_probs).reshape(-1, 1)
123+
for calibr, cls_probs in zip(self.calibrators, probs.T)))
124+
else:
125+
ps = probs.copy()
126+
sums = np.sum(ps, axis=1)
127+
zero_sums = sums == 0
128+
with np.errstate(invalid="ignore"): # handled below
129+
ps /= sums[:, None]
130+
if zero_sums.any():
131+
ps[zero_sums] = 1 / ps.shape[1]
132+
return ps
133+
134+
135+
class CalibratedLearner(Learner):
136+
"""
137+
Probability calibration for learning algorithms
138+
139+
This learner that wraps another learner, so that after training, it predicts
140+
the probabilities on training data and calibrates them using sigmoid or
141+
isotonic calibration. It then returns a :obj:`CalibratedClassifier`.
142+
143+
Attributes:
144+
base_learner (Learner): base learner
145+
calibration_method (int):
146+
`CalibratedLearner.Sigmoid` or `CalibratedLearner.Isotonic`
147+
"""
148+
__returns__ = CalibratedClassifier
149+
150+
Sigmoid, Isotonic = range(2)
151+
152+
def __init__(self, base_learner, calibration_method=Sigmoid):
153+
super().__init__()
154+
self.base_learner = base_learner
155+
self.calibration_method = calibration_method
156+
157+
def fit_storage(self, data):
158+
"""
159+
Induce a model using the provided `base_learner`, compute probabilities
160+
on training data and use scipy's `_SigmoidCalibration` or
161+
`IsotonicRegression` to prepare calibrators.
162+
"""
163+
res = TestOnTrainingData(data, [self.base_learner], store_models=True)
164+
model = res.models[0, 0]
165+
probabilities = res.probabilities[0]
166+
return self.get_model(model, res.actual, probabilities)
167+
168+
def get_model(self, model, ytrue, probabilities):
169+
if self.calibration_method == CalibratedLearner.Sigmoid:
170+
fitter = _SigmoidCalibration()
171+
else:
172+
fitter = IsotonicRegression(out_of_bounds='clip')
173+
probabilities[np.isinf(probabilities)] = 1
174+
calibrators = [fitter.fit(cls_probs, ytrue)
175+
for cls_idx, cls_probs in enumerate(probabilities.T)]
176+
return CalibratedClassifier(model, calibrators)
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import unittest
2+
from unittest.mock import Mock, patch
3+
4+
import numpy as np
5+
6+
from Orange.base import Model
7+
from Orange.classification.calibration import \
8+
ThresholdLearner, ThresholdClassifier, \
9+
CalibratedLearner, CalibratedClassifier
10+
from Orange.data import Table
11+
12+
13+
class TestThresholdClassifier(unittest.TestCase):
14+
def setUp(self):
15+
probs1 = np.array([0.3, 0.5, 0.2, 0.8, 0.9, 0]).reshape(-1, 1)
16+
self.probs = np.hstack((1 - probs1, probs1))
17+
base_model = Mock(return_value=self.probs)
18+
base_model.domain.class_var.is_discrete = True
19+
base_model.domain.class_var.values = ["a", "b"]
20+
self.model = ThresholdClassifier(base_model, 0.5)
21+
self.data = Mock()
22+
23+
def test_threshold(self):
24+
vals = self.model(self.data)
25+
np.testing.assert_equal(vals, [0, 1, 0, 1, 1, 0])
26+
27+
self.model.threshold = 0.8
28+
vals = self.model(self.data)
29+
np.testing.assert_equal(vals, [0, 0, 0, 1, 1, 0])
30+
31+
self.model.threshold = 0
32+
vals = self.model(self.data)
33+
np.testing.assert_equal(vals, [1] * 6)
34+
35+
def test_return_types(self):
36+
vals = self.model(self.data, ret=Model.Value)
37+
np.testing.assert_equal(vals, [0, 1, 0, 1, 1, 0])
38+
39+
vals = self.model(self.data)
40+
np.testing.assert_equal(vals, [0, 1, 0, 1, 1, 0])
41+
42+
probs = self.model(self.data, ret=Model.Probs)
43+
np.testing.assert_equal(probs, self.probs)
44+
45+
vals, probs = self.model(self.data, ret=Model.ValueProbs)
46+
np.testing.assert_equal(vals, [0, 1, 0, 1, 1, 0])
47+
np.testing.assert_equal(probs, self.probs)
48+
49+
def test_nans(self):
50+
self.probs[1, :] = np.nan
51+
vals, probs = self.model(self.data, ret=Model.ValueProbs)
52+
np.testing.assert_equal(vals, [0, np.nan, 0, 1, 1, 0])
53+
np.testing.assert_equal(probs, self.probs)
54+
55+
def test_non_binary_base(self):
56+
base_model = Mock()
57+
base_model.domain.class_var.is_discrete = True
58+
base_model.domain.class_var.values = ["a"]
59+
self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5)
60+
61+
base_model.domain.class_var.values = ["a", "b", "c"]
62+
self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5)
63+
64+
base_model.domain.class_var = Mock()
65+
base_model.domain.class_var.is_discrete = False
66+
self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5)
67+
68+
69+
class TestThresholdLearner(unittest.TestCase):
70+
@patch("Orange.evaluation.performance_curves.Curves.from_results")
71+
@patch("Orange.classification.calibration.TestOnTrainingData")
72+
def test_fit_storage(self, test_on_training, curves_from_results):
73+
curves_from_results.return_value = curves = Mock()
74+
curves.probs = np.array([0.1, 0.15, 0.3, 0.45, 0.6, 0.8])
75+
curves.ca = lambda: np.array([0.1, 0.7, 0.4, 0.4, 0.3, 0.1])
76+
curves.f1 = lambda: np.array([0.1, 0.2, 0.4, 0.4, 0.3, 0.1])
77+
model = Mock()
78+
model.domain.class_var.is_discrete = True
79+
model.domain.class_var.values = ("a", "b")
80+
data = Table("heart_disease")
81+
learner = Mock()
82+
test_on_training.return_value = res = Mock()
83+
res.models = np.array([[model]])
84+
test_on_training.return_value = res
85+
86+
thresh_learner = ThresholdLearner(
87+
base_learner=learner,
88+
threshold_criterion=ThresholdLearner.OptimizeCA)
89+
thresh_model = thresh_learner(data)
90+
self.assertEqual(thresh_model.threshold, 0.15)
91+
args, kwargs = test_on_training.call_args
92+
self.assertEqual(len(args), 2)
93+
self.assertIs(args[0], data)
94+
self.assertIs(args[1][0], learner)
95+
self.assertEqual(len(args[1]), 1)
96+
self.assertEqual(kwargs, {"store_models": 1})
97+
98+
thresh_learner = ThresholdLearner(
99+
base_learner=learner,
100+
threshold_criterion=ThresholdLearner.OptimizeF1)
101+
thresh_model = thresh_learner(data)
102+
self.assertEqual(thresh_model.threshold, 0.45)
103+
104+
def test_non_binary_class(self):
105+
thresh_learner = ThresholdLearner(
106+
base_learner=Mock(),
107+
threshold_criterion=ThresholdLearner.OptimizeF1)
108+
109+
data = Mock()
110+
data.domain.class_var.is_discrete = True
111+
data.domain.class_var.values = ["a"]
112+
self.assertRaises(ValueError, thresh_learner.fit_storage, data)
113+
114+
data.domain.class_var.values = ["a", "b", "c"]
115+
self.assertRaises(ValueError, thresh_learner.fit_storage, data)
116+
117+
data.domain.class_var = Mock()
118+
data.domain.class_var.is_discrete = False
119+
self.assertRaises(ValueError, thresh_learner.fit_storage, data)
120+
121+
122+
class TestCalibratedClassifier(unittest.TestCase):
123+
def setUp(self):
124+
probs1 = np.array([0.3, 0.5, 0.2, 0.8, 0.9, 0]).reshape(-1, 1)
125+
self.probs = np.hstack((1 - probs1, probs1))
126+
base_model = Mock(return_value=self.probs)
127+
base_model.domain.class_var.is_discrete = True
128+
base_model.domain.class_var.values = ["a", "b"]
129+
self.model = CalibratedClassifier(base_model, None)
130+
self.data = Mock()
131+
132+
def test_call(self):
133+
calprobs = np.arange(self.probs.size).reshape(self.probs.shape)
134+
calprobs = calprobs / np.sum(calprobs, axis=1)[:, None]
135+
calprobs[-1] = [0.7, 0.3]
136+
self.model.calibrated_probs = Mock(return_value=calprobs)
137+
138+
probs = self.model(self.data, ret=Model.Probs)
139+
self.model.calibrated_probs.assert_called_with(self.probs)
140+
np.testing.assert_almost_equal(probs, calprobs)
141+
142+
vals = self.model(self.data, ret=Model.Value)
143+
np.testing.assert_almost_equal(vals, [1, 1, 1, 1, 1, 0])
144+
145+
vals, probs = self.model(self.data, ret=Model.ValueProbs)
146+
np.testing.assert_almost_equal(probs, calprobs)
147+
np.testing.assert_almost_equal(vals, [1, 1, 1, 1, 1, 0])
148+
149+
def test_calibrated_probs(self):
150+
self.model.calibrators = None
151+
calprobs = self.model.calibrated_probs(self.probs)
152+
np.testing.assert_equal(calprobs, self.probs)
153+
self.assertIsNot(calprobs, self.probs)
154+
155+
calibrator = Mock()
156+
calibrator.predict = lambda x: x**2
157+
self.model.calibrators = [calibrator] * 2
158+
calprobs = self.model.calibrated_probs(self.probs)
159+
expprobs = self.probs ** 2 / np.sum(self.probs ** 2, axis=1)[:, None]
160+
np.testing.assert_almost_equal(calprobs, expprobs)
161+
162+
self.probs[1] = 0
163+
self.probs[2] = np.nan
164+
expprobs[1] = 0.5
165+
expprobs[2] = np.nan
166+
calprobs = self.model.calibrated_probs(self.probs)
167+
np.testing.assert_almost_equal(calprobs, expprobs)
168+
169+
170+
class TestCalibratedLearner(unittest.TestCase):
171+
@patch("Orange.classification.calibration._SigmoidCalibration.fit")
172+
@patch("Orange.classification.calibration.TestOnTrainingData")
173+
def test_fit_storage(self, test_on_training, sigmoid_fit):
174+
data = Table("heart_disease")
175+
learner = Mock()
176+
177+
model = Mock()
178+
model.domain.class_var.is_discrete = True
179+
model.domain.class_var.values = ("a", "b")
180+
181+
test_on_training.return_value = res = Mock()
182+
res.models = np.array([[model]])
183+
res.probabilities = np.arange(20, dtype=float).reshape(1, 5, 4)
184+
test_on_training.return_value = res
185+
186+
sigmoid_fit.return_value = Mock()
187+
188+
cal_learner = CalibratedLearner(
189+
base_learner=learner, calibration_method=CalibratedLearner.Sigmoid)
190+
cal_model = cal_learner(data)
191+
192+
self.assertIs(cal_model.base_model, model)
193+
self.assertEqual(cal_model.calibrators, [sigmoid_fit.return_value] * 4)
194+
args, kwargs = test_on_training.call_args
195+
self.assertEqual(len(args), 2)
196+
self.assertIs(args[0], data)
197+
self.assertIs(args[1][0], learner)
198+
self.assertEqual(len(args[1]), 1)
199+
self.assertEqual(kwargs, {"store_models": 1})
200+
201+
for call, cls_probs in zip(sigmoid_fit.call_args_list,
202+
res.probabilities[0].T):
203+
np.testing.assert_equal(call[0][0], cls_probs)

0 commit comments

Comments
 (0)