Skip to content

Commit 94fa423

Browse files
committed
Enable numpy classification test for calibration
1 parent 0b119c3 commit 94fa423

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

Orange/tests/test_classification.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,20 +245,29 @@ def test_result_shape_numpy(self):
245245
Test whether results shapes are correct when testing on numpy data
246246
"""
247247
iris = Table('iris')
248+
iris_bin = Table(
249+
Domain(
250+
iris.domain.attributes,
251+
DiscreteVariable("iris", values=["a", "b"])
252+
),
253+
iris.X[:100], iris.Y[:100]
254+
)
248255
for learner in all_learners():
249256
with self.subTest(learner.__name__):
250-
try:
251-
model = learner()(iris)
252-
except TypeError:
253-
# cannot be tested with default parameters
254-
continue
255-
transformed_iris = model.data_to_model_domain(iris)
257+
args = []
258+
if learner in (ThresholdLearner, CalibratedLearner):
259+
args = [LogisticRegressionLearner()]
260+
data = iris_bin if learner is ThresholdLearner else iris
261+
model = learner(*args)(data)
262+
transformed_iris = model.data_to_model_domain(data)
256263

257264
res = model(transformed_iris.X[0:5])
258265
self.assertTupleEqual((5,), res.shape)
259266

260267
res = model(transformed_iris.X[0:1], model.Probs)
261-
self.assertTupleEqual((1, 3), res.shape)
268+
self.assertTupleEqual(
269+
(1, len(data.domain.class_var.values)), res.shape
270+
)
262271

263272

264273
class ExpandProbabilitiesTest(unittest.TestCase):

0 commit comments

Comments
 (0)