44import numpy as np
55
66from Orange .base import Model
7+ from Orange .classification import LogisticRegressionLearner
78from Orange .classification .calibration import \
89 ThresholdLearner , ThresholdClassifier , \
910 CalibratedLearner , CalibratedClassifier
@@ -65,6 +66,18 @@ def test_non_binary_base(self):
6566 base_model .domain .class_var .is_discrete = False
6667 self .assertRaises (ValueError , ThresholdClassifier , base_model , 0.5 )
6768
69+ @staticmethod
70+ def test_np_data ():
71+ """
72+ Test ThresholdModel with numpy data.
73+ When passing numpy data to model they should be already
74+ transformed to models domain since model do not know how to do it.
75+ """
76+ data = Table ('heart_disease' )
77+ base_learner = LogisticRegressionLearner ()
78+ model = ThresholdLearner (base_learner )(data )
79+ model (model .data_to_model_domain (data ).X )
80+
6881
6982class TestThresholdLearner (unittest .TestCase ):
7083 @patch ("Orange.evaluation.performance_curves.Curves.from_results" )
@@ -166,6 +179,18 @@ def test_calibrated_probs(self):
166179 calprobs = self .model .calibrated_probs (self .probs )
167180 np .testing .assert_almost_equal (calprobs , expprobs )
168181
182+ @staticmethod
183+ def test_np_data ():
184+ """
185+ Test CalibratedClassifier with numpy data.
186+ When passing numpy data to model they should be already
187+ transformed to models domain since model do not know how to do it.
188+ """
189+ data = Table ('heart_disease' )
190+ base_learner = LogisticRegressionLearner ()
191+ model = CalibratedLearner (base_learner )(data )
192+ model (model .data_to_model_domain (data ).X )
193+
169194
170195class TestCalibratedLearner (unittest .TestCase ):
171196 @patch ("Orange.classification.calibration._SigmoidCalibration.fit" )
@@ -201,3 +226,7 @@ def test_fit_storage(self, test_on_training, sigmoid_fit):
201226 for call , cls_probs in zip (sigmoid_fit .call_args_list ,
202227 res .probabilities [0 ].T ):
203228 np .testing .assert_equal (call [0 ][0 ], cls_probs )
229+
230+
231+ if __name__ == "__main__" :
232+ unittest .main ()
0 commit comments