|
4 | 4 | import numpy as np |
5 | 5 |
|
6 | 6 | from Orange.base import Model |
| 7 | +from Orange.classification import LogisticRegressionLearner |
7 | 8 | from Orange.classification.calibration import \ |
8 | 9 | ThresholdLearner, ThresholdClassifier, \ |
9 | 10 | CalibratedLearner, CalibratedClassifier |
@@ -65,6 +66,18 @@ def test_non_binary_base(self): |
65 | 66 | base_model.domain.class_var.is_discrete = False |
66 | 67 | self.assertRaises(ValueError, ThresholdClassifier, base_model, 0.5) |
67 | 68 |
|
| 69 | + @staticmethod |
| 70 | + def test_np_data(): |
| 71 | + """ |
| 72 | + Test ThresholdModel with numpy data. |
| 73 | + When passing numpy data to model they should be already |
| 74 | + transformed to models domain since model do not know how to do it. |
| 75 | + """ |
| 76 | + data = Table('heart_disease') |
| 77 | + base_learner = LogisticRegressionLearner() |
| 78 | + model = ThresholdLearner(base_learner)(data) |
| 79 | + model(model.data_to_model_domain(data).X) |
| 80 | + |
68 | 81 |
|
69 | 82 | class TestThresholdLearner(unittest.TestCase): |
70 | 83 | @patch("Orange.evaluation.performance_curves.Curves.from_results") |
@@ -166,6 +179,18 @@ def test_calibrated_probs(self): |
166 | 179 | calprobs = self.model.calibrated_probs(self.probs) |
167 | 180 | np.testing.assert_almost_equal(calprobs, expprobs) |
168 | 181 |
|
| 182 | + @staticmethod |
| 183 | + def test_np_data(): |
| 184 | + """ |
| 185 | + Test CalibratedClassifier with numpy data. |
| 186 | + When passing numpy data to model they should be already |
| 187 | + transformed to models domain since model do not know how to do it. |
| 188 | + """ |
| 189 | + data = Table('heart_disease') |
| 190 | + base_learner = LogisticRegressionLearner() |
| 191 | + model = CalibratedLearner(base_learner)(data) |
| 192 | + model(model.data_to_model_domain(data).X) |
| 193 | + |
169 | 194 |
|
170 | 195 | class TestCalibratedLearner(unittest.TestCase): |
171 | 196 | @patch("Orange.classification.calibration._SigmoidCalibration.fit") |
@@ -201,3 +226,7 @@ def test_fit_storage(self, test_on_training, sigmoid_fit): |
201 | 226 | for call, cls_probs in zip(sigmoid_fit.call_args_list, |
202 | 227 | res.probabilities[0].T): |
203 | 228 | np.testing.assert_equal(call[0][0], cls_probs) |
| 229 | + |
| 230 | + |
| 231 | +if __name__ == "__main__": |
| 232 | + unittest.main() |
0 commit comments