Skip to content

Commit 7a7c635

Browse files
authored
Merge pull request #5159 from PrimozGodec/fix-calibration
[FIX] Calibration model: Work with numpy data
2 parents 61f7d5c + 535547e commit 7a7c635

File tree

3 files changed

+50
-8
lines changed

3 files changed

+50
-8
lines changed

Orange/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ def __call__(self, data, progress_callback=None):
136136
progress_callback(0.1, "Fitting...")
137137
model = self._fit_model(data)
138138
model.used_vals = [np.unique(y).astype(int) for y in data.Y[:, None].T]
139-
model.domain = data.domain
139+
if not hasattr(model, "domain") or model.domain is None:
140+
# some models set the domain themselves, and that should be respected
141+
# e.g. calibration learners set the base_learner's domain which
142+
# would be wrongly overwritten if we set it here for any model
143+
model.domain = data.domain
140144
model.supports_multiclass = self.supports_multiclass
141145
model.name = self.name
142146
model.original_domain = origdomain

Orange/classification/tests/test_calibration.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from Orange.base import Model
7+
from Orange.classification import LogisticRegressionLearner
78
from Orange.classification.calibration import \
89
ThresholdLearner, ThresholdClassifier, \
910
CalibratedLearner, CalibratedClassifier
@@ -65,6 +66,18 @@ def test_non_binary_base(self):
6566
base_model.domain.class_var.is_discrete = False
6667
self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5)
6768

69+
def test_np_data(self):
70+
"""
71+
Test ThresholdModel with numpy data.
72+
When passing numpy data to the model, they should already be
73+
transformed to the model's domain, since the model does not know how to do it.
74+
"""
75+
data = Table('heart_disease')
76+
base_learner = LogisticRegressionLearner()
77+
model = ThresholdLearner(base_learner)(data)
78+
res = model(model.data_to_model_domain(data).X)
79+
self.assertTupleEqual((len(data),), res.shape)
80+
6881

6982
class TestThresholdLearner(unittest.TestCase):
7083
@patch("Orange.evaluation.performance_curves.Curves.from_results")
@@ -169,6 +182,18 @@ def test_calibrated_probs(self):
169182
calprobs = self.model.calibrated_probs(self.probs)
170183
np.testing.assert_almost_equal(calprobs, expprobs)
171184

185+
def test_np_data(self):
186+
"""
187+
Test CalibratedClassifier with numpy data.
188+
When passing numpy data to the model, they should already be
189+
transformed to the model's domain, since the model does not know how to do it.
190+
"""
191+
data = Table('heart_disease')
192+
base_learner = LogisticRegressionLearner()
193+
model = CalibratedLearner(base_learner)(data)
194+
res = model(model.data_to_model_domain(data).X)
195+
self.assertTupleEqual((len(data),), res.shape)
196+
172197

173198
class TestCalibratedLearner(unittest.TestCase):
174199
@patch("Orange.classification.calibration._SigmoidCalibration.fit")
@@ -207,3 +232,7 @@ def test_fit_storage(self, test_on_training, sigmoid_fit):
207232
for call, cls_probs in zip(sigmoid_fit.call_args_list,
208233
res.probabilities[0].T):
209234
np.testing.assert_equal(call[0][0], cls_probs)
235+
236+
237+
if __name__ == "__main__":
238+
unittest.main()

Orange/tests/test_classification.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,20 +245,29 @@ def test_result_shape_numpy(self):
245245
Test whether results shapes are correct when testing on numpy data
246246
"""
247247
iris = Table('iris')
248+
iris_bin = Table(
249+
Domain(
250+
iris.domain.attributes,
251+
DiscreteVariable("iris", values=["a", "b"])
252+
),
253+
iris.X[:100], iris.Y[:100]
254+
)
248255
for learner in all_learners():
249256
with self.subTest(learner.__name__):
250-
try:
251-
model = learner()(iris)
252-
except TypeError:
253-
# cannot be tested with default parameters
254-
continue
255-
transformed_iris = model.data_to_model_domain(iris)
257+
args = []
258+
if learner in (ThresholdLearner, CalibratedLearner):
259+
args = [LogisticRegressionLearner()]
260+
data = iris_bin if learner is ThresholdLearner else iris
261+
model = learner(*args)(data)
262+
transformed_iris = model.data_to_model_domain(data)
256263

257264
res = model(transformed_iris.X[0:5])
258265
self.assertTupleEqual((5,), res.shape)
259266

260267
res = model(transformed_iris.X[0:1], model.Probs)
261-
self.assertTupleEqual((1, 3), res.shape)
268+
self.assertTupleEqual(
269+
(1, len(data.domain.class_var.values)), res.shape
270+
)
262271

263272

264273
class ExpandProbabilitiesTest(unittest.TestCase):

0 commit comments

Comments
 (0)