Skip to content

Commit 2d54536

Browse files
committed
Calibration: prevent wrongly overwriting the domain
1 parent 2d216ee commit 2d54536

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

Orange/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ def __call__(self, data, progress_callback=None):
136136
progress_callback(0.1, "Fitting...")
137137
model = self._fit_model(data)
138138
model.used_vals = [np.unique(y).astype(int) for y in data.Y[:, None].T]
139-
model.domain = data.domain
139+
if not hasattr(model, "domain") or model.domain is None:
140+
# some models set domain themself and it should be respected
141+
# e.g. calibration learners set the base_learner's domain which
142+
# would be wrongly overwritten if we set it here for any model
143+
model.domain = data.domain
140144
model.supports_multiclass = self.supports_multiclass
141145
model.name = self.name
142146
model.original_domain = origdomain

Orange/classification/calibration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
import numpy as np
23
from sklearn.isotonic import IsotonicRegression
34
from sklearn.calibration import _SigmoidCalibration

Orange/classification/tests/test_calibration.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from Orange.base import Model
7+
from Orange.classification import LogisticRegressionLearner
78
from Orange.classification.calibration import \
89
ThresholdLearner, ThresholdClassifier, \
910
CalibratedLearner, CalibratedClassifier
@@ -65,6 +66,17 @@ def test_non_binary_base(self):
6566
base_model.domain.class_var.is_discrete = False
6667
self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5)
6768

69+
def test_np_data(self):
70+
"""
71+
Test ThresholdModel with numpy data.
72+
When passing numpy data to model they should be already
73+
transformed to models domain since model do not know how to do it.
74+
"""
75+
data = Table('heart_disease')
76+
base_learner = LogisticRegressionLearner()
77+
model = ThresholdLearner(base_learner)(data)
78+
model(model.data_to_model_domain(data).X)
79+
6880

6981
class TestThresholdLearner(unittest.TestCase):
7082
@patch("Orange.evaluation.performance_curves.Curves.from_results")
@@ -169,6 +181,17 @@ def test_calibrated_probs(self):
169181
calprobs = self.model.calibrated_probs(self.probs)
170182
np.testing.assert_almost_equal(calprobs, expprobs)
171183

184+
def test_np_data(self):
185+
"""
186+
Test CalibratedClassifier with numpy data.
187+
When passing numpy data to model they should be already
188+
transformed to models domain since model do not know how to do it.
189+
"""
190+
data = Table('heart_disease')
191+
base_learner = LogisticRegressionLearner()
192+
model = CalibratedLearner(base_learner)(data)
193+
model(model.data_to_model_domain(data).X)
194+
172195

173196
class TestCalibratedLearner(unittest.TestCase):
174197
@patch("Orange.classification.calibration._SigmoidCalibration.fit")
@@ -207,3 +230,7 @@ def test_fit_storage(self, test_on_training, sigmoid_fit):
207230
for call, cls_probs in zip(sigmoid_fit.call_args_list,
208231
res.probabilities[0].T):
209232
np.testing.assert_equal(call[0][0], cls_probs)
233+
234+
235+
if __name__ == "__main__":
236+
unittest.main()

0 commit comments

Comments
 (0)