Skip to content

Commit fd59605

Browse files
authored
Merge pull request #5168 from JakaKokosar/enable_classification_tests
[FIX] Enable classification tests
2 parents 919dc35 + 458e866 commit fd59605

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

Orange/classification/tests/__init__.py

Whitespace-only changes.

Orange/classification/tests/test_base.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
from tempfile import TemporaryDirectory
31
import unittest
42

53
import numpy as np
@@ -12,17 +10,21 @@ class TestModelMapping(unittest.TestCase):
1210
@classmethod
1311
def setUpClass(cls):
1412
cls.iris = iris = Table("iris")
15-
with TemporaryDirectory() as tempdir:
16-
tables = []
17-
x = np.vstack((iris.X[:50], iris.X[100:]))
18-
y = np.hstack((iris.Y[:50], iris.Y[100:]))
19-
for i, data in enumerate([iris[50:],
20-
Table.from_numpy(iris.domain, x, y),
21-
iris[:100]]):
22-
23-
name = os.path.join(tempdir, f"no{i}.tab")
24-
data.save(name)
25-
tables.append(Table(name))
13+
14+
tables = []
15+
ix = iris.X
16+
y = np.hstack((np.zeros(50), np.ones(50)))
17+
attrs = cls.iris.domain.attributes
18+
classes = cls.iris.domain.class_var.values
19+
for i, x in enumerate([ix[50:],
20+
np.vstack((ix[:50], ix[100:])),
21+
ix[:100]]):
22+
class_var = DiscreteVariable(
23+
"iris",
24+
values=tuple(n for j, n in enumerate(classes) if j != i))
25+
domain = Domain(attrs, class_var)
26+
tables.append(Table.from_numpy(domain, x, y))
27+
# pylint: disable=unbalanced-tuple-unpacking
2628
cls.iris0, cls.iris1, cls.iris2 = tables
2729

2830
def test_larger_model(self):

Orange/classification/tests/test_calibration.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,22 @@ def test_fit_storage(self, test_on_training, curves_from_results):
7979
model.domain.class_var.values = ("a", "b")
8080
data = Table("heart_disease")
8181
learner = Mock()
82-
test_on_training.return_value = res = Mock()
82+
test_on_training.return_value = tot = Mock()
83+
res = Mock()
8384
res.models = np.array([[model]])
84-
test_on_training.return_value = res
85+
tot.return_value = res
8586

8687
thresh_learner = ThresholdLearner(
8788
base_learner=learner,
8889
threshold_criterion=ThresholdLearner.OptimizeCA)
8990
thresh_model = thresh_learner(data)
9091
self.assertEqual(thresh_model.threshold, 0.15)
91-
args, kwargs = test_on_training.call_args
92+
args, _ = tot.call_args # pylint: disable=unpacking-non-sequence
9293
self.assertEqual(len(args), 2)
9394
self.assertIs(args[0], data)
9495
self.assertIs(args[1][0], learner)
96+
97+
_, kwargs = test_on_training.call_args
9598
self.assertEqual(len(args[1]), 1)
9699
self.assertEqual(kwargs, {"store_models": 1})
97100

@@ -178,10 +181,11 @@ def test_fit_storage(self, test_on_training, sigmoid_fit):
178181
model.domain.class_var.is_discrete = True
179182
model.domain.class_var.values = ("a", "b")
180183

181-
test_on_training.return_value = res = Mock()
184+
test_on_training.return_value = tot = Mock()
185+
res = Mock()
182186
res.models = np.array([[model]])
183187
res.probabilities = np.arange(20, dtype=float).reshape(1, 5, 4)
184-
test_on_training.return_value = res
188+
tot.return_value = res
185189

186190
sigmoid_fit.return_value = Mock()
187191

@@ -191,11 +195,13 @@ def test_fit_storage(self, test_on_training, sigmoid_fit):
191195

192196
self.assertIs(cal_model.base_model, model)
193197
self.assertEqual(cal_model.calibrators, [sigmoid_fit.return_value] * 4)
194-
args, kwargs = test_on_training.call_args
198+
args, _ = tot.call_args # pylint: disable=unpacking-non-sequence
195199
self.assertEqual(len(args), 2)
196200
self.assertIs(args[0], data)
197201
self.assertIs(args[1][0], learner)
198202
self.assertEqual(len(args[1]), 1)
203+
204+
_, kwargs = test_on_training.call_args
199205
self.assertEqual(kwargs, {"store_models": 1})
200206

201207
for call, cls_probs in zip(sigmoid_fit.call_args_list,

0 commit comments

Comments
 (0)