Skip to content

Commit db964ae

Browse files
committed
Calibration: data_to_model_domain call base_model's function
1 parent f7fd48c commit db964ae

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

Orange/classification/calibration.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
from collections import Callable
2+
13
import numpy as np
24
from sklearn.isotonic import IsotonicRegression
35
from sklearn.calibration import _SigmoidCalibration
46

57
from Orange.classification import Model, Learner
8+
from Orange.data import Table
69
from Orange.evaluation import TestOnTrainingData
710
from Orange.evaluation.performance_curves import Curves
811

912
__all__ = ["ThresholdClassifier", "ThresholdLearner",
1013
"CalibratedLearner", "CalibratedClassifier"]
1114

15+
from Orange.util import dummy_callback
16+
1217

1318
class ThresholdClassifier(Model):
1419
"""
@@ -31,6 +36,11 @@ def __init__(self, base_model, threshold):
3136
self.base_model = base_model
3237
self.threshold = threshold
3338

39+
def data_to_model_domain(
40+
self, data: Table, progress_callback: Callable = dummy_callback
41+
) -> Table:
42+
return self.base_model.data_to_model_domain(data, progress_callback)
43+
3444
def __call__(self, data, ret=Model.Value):
3545
probs = self.base_model(data, ret=Model.Probs)
3646
if ret == Model.Probs:
@@ -104,6 +114,11 @@ def __init__(self, base_model, calibrators):
104114
self.calibrators = calibrators
105115
self.name = f"{base_model.name}, calibrated"
106116

117+
def data_to_model_domain(
118+
self, data: Table, progress_callback: Callable = dummy_callback
119+
) -> Table:
120+
return self.base_model.data_to_model_domain(data, progress_callback)
121+
107122
def __call__(self, data, ret=Model.Value):
108123
probs = self.base_model(data, Model.Probs)
109124
cal_probs = self.calibrated_probs(probs)

Orange/classification/tests/test_calibration.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from Orange.base import Model
7+
from Orange.classification import LogisticRegressionLearner
78
from Orange.classification.calibration import \
89
ThresholdLearner, ThresholdClassifier, \
910
CalibratedLearner, CalibratedClassifier
@@ -65,6 +66,18 @@ def test_non_binary_base(self):
6566
base_model.domain.class_var.is_discrete = False
6667
self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5)
6768

69+
@staticmethod
70+
def test_np_data():
71+
"""
72+
Test ThresholdModel with numpy data.
73+
When passing numpy data to model they should be already
74+
transformed to models domain since model do not know how to do it.
75+
"""
76+
data = Table('heart_disease')
77+
base_learner = LogisticRegressionLearner()
78+
model = ThresholdLearner(base_learner)(data)
79+
model(model.data_to_model_domain(data).X)
80+
6881

6982
class TestThresholdLearner(unittest.TestCase):
7083
@patch("Orange.evaluation.performance_curves.Curves.from_results")
@@ -166,6 +179,18 @@ def test_calibrated_probs(self):
166179
calprobs = self.model.calibrated_probs(self.probs)
167180
np.testing.assert_almost_equal(calprobs, expprobs)
168181

182+
@staticmethod
183+
def test_np_data():
184+
"""
185+
Test CalibratedClassifier with numpy data.
186+
When passing numpy data to model they should be already
187+
transformed to models domain since model do not know how to do it.
188+
"""
189+
data = Table('heart_disease')
190+
base_learner = LogisticRegressionLearner()
191+
model = CalibratedLearner(base_learner)(data)
192+
model(model.data_to_model_domain(data).X)
193+
169194

170195
class TestCalibratedLearner(unittest.TestCase):
171196
@patch("Orange.classification.calibration._SigmoidCalibration.fit")
@@ -201,3 +226,7 @@ def test_fit_storage(self, test_on_training, sigmoid_fit):
201226
for call, cls_probs in zip(sigmoid_fit.call_args_list,
202227
res.probabilities[0].T):
203228
np.testing.assert_equal(call[0][0], cls_probs)
229+
230+
231+
if __name__ == "__main__":
232+
unittest.main()

0 commit comments

Comments
 (0)