@@ -245,20 +245,29 @@ def test_result_shape_numpy(self):
245245 Test whether results shapes are correct when testing on numpy data
246246 """
247247 iris = Table ('iris' )
248+ iris_bin = Table (
249+ Domain (
250+ iris .domain .attributes ,
251+ DiscreteVariable ("iris" , values = ["a" , "b" ])
252+ ),
253+ iris .X [:100 ], iris .Y [:100 ]
254+ )
248255 for learner in all_learners ():
249256 with self .subTest (learner .__name__ ):
250- try :
251- model = learner ()( iris )
252- except TypeError :
253- # cannot be tested with default parameters
254- continue
255- transformed_iris = model .data_to_model_domain (iris )
257+ args = []
258+ if learner in ( ThresholdLearner , CalibratedLearner ):
259+ args = [ LogisticRegressionLearner ()]
260+ data = iris_bin if learner is ThresholdLearner else iris
261+ model = learner ( * args )( data )
262+ transformed_iris = model .data_to_model_domain (data )
256263
257264 res = model (transformed_iris .X [0 :5 ])
258265 self .assertTupleEqual ((5 ,), res .shape )
259266
260267 res = model (transformed_iris .X [0 :1 ], model .Probs )
261- self .assertTupleEqual ((1 , 3 ), res .shape )
268+ self .assertTupleEqual (
269+ (1 , len (data .domain .class_var .values )), res .shape
270+ )
262271
263272
264273class ExpandProbabilitiesTest (unittest .TestCase ):
0 commit comments