|
4 | 4 | import numpy as np |
5 | 5 |
|
6 | 6 | from Orange.base import Model |
| 7 | +from Orange.classification import LogisticRegressionLearner |
7 | 8 | from Orange.classification.calibration import \ |
8 | 9 | ThresholdLearner, ThresholdClassifier, \ |
9 | 10 | CalibratedLearner, CalibratedClassifier |
@@ -65,6 +66,18 @@ def test_non_binary_base(self): |
65 | 66 | base_model.domain.class_var.is_discrete = False |
66 | 67 | self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5) |
67 | 68 |
|
| 69 | + def test_np_data(self): |
| 70 | + """ |
| 71 | + Test ThresholdModel with numpy data. |
| 72 | + When passing numpy data to model they should be already |
| 73 | + transformed to models domain since model do not know how to do it. |
| 74 | + """ |
| 75 | + data = Table('heart_disease') |
| 76 | + base_learner = LogisticRegressionLearner() |
| 77 | + model = ThresholdLearner(base_learner)(data) |
| 78 | + res = model(model.data_to_model_domain(data).X) |
| 79 | + self.assertTupleEqual((len(data),), res.shape) |
| 80 | + |
68 | 81 |
|
69 | 82 | class TestThresholdLearner(unittest.TestCase): |
70 | 83 | @patch("Orange.evaluation.performance_curves.Curves.from_results") |
@@ -169,6 +182,18 @@ def test_calibrated_probs(self): |
169 | 182 | calprobs = self.model.calibrated_probs(self.probs) |
170 | 183 | np.testing.assert_almost_equal(calprobs, expprobs) |
171 | 184 |
|
| 185 | + def test_np_data(self): |
| 186 | + """ |
| 187 | + Test CalibratedClassifier with numpy data. |
| 188 | + When passing numpy data to model they should be already |
| 189 | + transformed to models domain since model do not know how to do it. |
| 190 | + """ |
| 191 | + data = Table('heart_disease') |
| 192 | + base_learner = LogisticRegressionLearner() |
| 193 | + model = CalibratedLearner(base_learner)(data) |
| 194 | + res = model(model.data_to_model_domain(data).X) |
| 195 | + self.assertTupleEqual((len(data),), res.shape) |
| 196 | + |
172 | 197 |
|
173 | 198 | class TestCalibratedLearner(unittest.TestCase): |
174 | 199 | @patch("Orange.classification.calibration._SigmoidCalibration.fit") |
@@ -207,3 +232,7 @@ def test_fit_storage(self, test_on_training, sigmoid_fit): |
207 | 232 | for call, cls_probs in zip(sigmoid_fit.call_args_list, |
208 | 233 | res.probabilities[0].T): |
209 | 234 | np.testing.assert_equal(call[0][0], cls_probs) |
| 235 | + |
| 236 | + |
| 237 | +if __name__ == "__main__": |
| 238 | + unittest.main() |
0 commit comments