|
1 | 1 | import unittest |
| 2 | +import numpy as np |
2 | 3 |
|
3 | 4 | from orangewidget.tests.base import WidgetTest |
4 | 5 |
|
5 | 6 | from Orange.data import Table |
6 | 7 | from Orange.preprocess import Impute |
7 | 8 |
|
8 | | -from Orange.classification.scoringsheet import ScoringSheetLearner |
9 | 9 | from Orange.widgets.model.owscoringsheet import OWScoringSheet |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class TestOWScoringSheet(WidgetTest): |
13 | | - @classmethod |
14 | | - def setUpClass(cls): |
15 | | - super().setUpClass() |
16 | | - cls.heart = Table("heart_disease") |
17 | | - cls.housing = Table("housing") |
18 | | - cls.scoring_sheet_learner = ScoringSheetLearner(20, 5, 5, None) |
19 | | - cls.scoring_sheet_model = cls.scoring_sheet_learner(cls.heart) |
20 | | - |
21 | 13 | def setUp(self): |
22 | 14 | self.widget = self.create_widget(OWScoringSheet) |
| 15 | + self.heart = Table("heart_disease") |
| 16 | + self.housing = Table("housing") |
23 | 17 |
|
24 | 18 | def test_no_data_input(self): |
25 | 19 | self.assertIsNotNone(self.get_output(self.widget.Outputs.learner)) |
@@ -64,14 +58,49 @@ def test_settings_in_model(self): |
64 | 58 |
|
65 | 59 | self.assertEqual(len(coefficients), self.widget.num_attr_after_selection) |
66 | 60 |
|
67 | | - # most often equal, but in some cases the optimizer finds fewer parameters |
68 | | - self.assertLessEqual(len(non_zero_coefficients), self.widget.num_decision_params) |
| 61 | + self.assertEqual(len(non_zero_coefficients), self.widget.num_decision_params) |
69 | 62 |
|
70 | 63 | self.assertLessEqual( |
71 | 64 | max(non_zero_coefficients, key=lambda x: abs(x)), |
72 | 65 | self.widget.max_points_per_param, |
73 | 66 | ) |
74 | 67 |
|
| 68 | + def test_model_reproducibility(self): |
| 69 | + self.widget = self.create_widget(OWScoringSheet) |
| 70 | + self.widget.num_attr_after_selection = 20 |
| 71 | + self.widget.num_decision_params = 7 |
| 72 | + self.widget.max_points_per_param = 8 |
| 73 | + self.widget.custom_features_checkbox = True |
| 74 | + self.widget.num_input_features = 4 |
| 75 | + |
| 76 | + self.widget.apply() |
| 77 | + |
| 78 | + self.send_signal(self.widget.Inputs.data, self.heart) |
| 79 | + self.wait_until_finished() |
| 80 | + model = self.get_output(self.widget.Outputs.model) |
| 81 | + |
| 82 | + coefficients = np.array( |
| 83 | + [ |
| 84 | + -8.0, 6.0, 0.0, 0.0, -3.0, 4.0, 0.0, -2.0, -1.0, |
| 85 | + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -6.0, 0.0, 0.0, 0.0, 0.0, |
| 86 | + ] |
| 87 | + ) |
| 88 | + feature_names = [ |
| 89 | + "major vessels colored=< 1", "chest pain=asymptomatic", "gender=female", |
| 90 | + "gender=male", "thal=normal", "thal=reversable defect", "rest SBP=125 - 150", |
| 91 | + "chest pain=non-anginal", "major vessels colored=1 - 2", "major vessels colored=2 - 3", |
| 92 | + "chest pain=atypical ang", "chest pain=typical ang", "rest SBP=150 - 175", |
| 93 | + "rest ECG=left vent hypertrophy", "rest ECG=normal", "ST by exercise=< 2", |
| 94 | + "rest SBP=100 - 125", "exerc ind ang=0", "exerc ind ang=1", "age=40 - 60", |
| 95 | + ] |
| 96 | + intercept = 7.0 |
| 97 | + multiplier = 3.4567159 |
| 98 | + |
| 99 | + np.testing.assert_equal(model.model.coefficients, coefficients) |
| 100 | + self.assertEqual(model.model.featureNames, feature_names) |
| 101 | + self.assertEqual(model.model.intercept, intercept) |
| 102 | + self.assertAlmostEqual(model.model.multiplier, multiplier, places=5) |
| 103 | + |
75 | 104 | def test_custom_number_input_features_information(self): |
76 | 105 | self.widget.custom_features_checkbox = True |
77 | 106 | self.widget.custom_input_features() |
|
0 commit comments