diff --git a/Orange/widgets/visualize/ownomogram.py b/Orange/widgets/visualize/ownomogram.py index d41d9ea0a88..c1c24b6007d 100644 --- a/Orange/widgets/visualize/ownomogram.py +++ b/Orange/widgets/visualize/ownomogram.py @@ -910,8 +910,7 @@ def calculate_log_reg_coefficients(self): if not isinstance(self.classifier, LogisticRegressionClassifier): return - self.domain = self.reconstruct_domain(self.classifier.original_domain, - self.domain) + self.domain = self.reconstruct_domain(self.classifier, self.domain) self.data = self.classifier.original_data.transform(self.domain) attrs, ranges, start = self.domain.attributes, [], 0 for attr in attrs: @@ -1273,8 +1272,9 @@ def send_report(self): self.report_plot() @staticmethod - def reconstruct_domain(original, preprocessed): + def reconstruct_domain(classifier: Model, preprocessed: Domain) -> Domain: # abuse dict to make "in" comparisons faster + original = classifier.original_domain attrs = OrderedDict() for attr in preprocessed.attributes: cv = attr._compute_value.variable._compute_value @@ -1284,7 +1284,13 @@ def reconstruct_domain(original, preprocessed): continue attrs[var] = None # we only need keys attrs = list(attrs.keys()) - return Domain(attrs, original.class_var, original.metas) + + orig_clv = original.class_var + orig_data = classifier.original_data + values = (orig_clv.values[int(i)] for i in + np.unique(orig_data.get_column_view(orig_clv)[0])) + class_var = DiscreteVariable(original.class_var.name, values) + return Domain(attrs, class_var, original.metas) @staticmethod def get_ruler_values(start, stop, max_width, round_to_nearest=True): diff --git a/Orange/widgets/visualize/tests/test_ownomogram.py b/Orange/widgets/visualize/tests/test_ownomogram.py index 538dfb0e5bc..87bb0e07c4b 100644 --- a/Orange/widgets/visualize/tests/test_ownomogram.py +++ b/Orange/widgets/visualize/tests/test_ownomogram.py @@ -14,6 +14,7 @@ from Orange.preprocess import Scale, Continuize from Orange.tests import test_filename from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.tests.utils import simulate from Orange.widgets.visualize.ownomogram import ( OWNomogram, DiscreteFeatureItem, ContinuousFeatureItem, ProbabilitiesDotItem, MovableToolTip @@ -293,25 +294,52 @@ def test_dots_stop_flashing(self): def test_reconstruct_domain(self): data = Table("heart_disease") cls = LogisticRegressionLearner()(data) - domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain) + domain = OWNomogram.reconstruct_domain(cls, cls.domain) transformed_data = cls.original_data.transform(domain) self.assertEqual(transformed_data.X.shape, data.X.shape) self.assertFalse(np.isnan(transformed_data.X[0]).any()) scaled_data = Scale()(data) cls = LogisticRegressionLearner()(scaled_data) - domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain) + domain = OWNomogram.reconstruct_domain(cls, cls.domain) transformed_data = cls.original_data.transform(domain) self.assertEqual(transformed_data.X.shape, scaled_data.X.shape) self.assertFalse(np.isnan(transformed_data.X[0]).any()) disc_data = Continuize()(data) cls = LogisticRegressionLearner()(disc_data) - domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain) + domain = OWNomogram.reconstruct_domain(cls, cls.domain) transformed_data = cls.original_data.transform(domain) self.assertEqual(transformed_data.X.shape, disc_data.X.shape) self.assertFalse(np.isnan(transformed_data.X[0]).any()) + def test_missing_class_value(self): + iris = Table("iris") + iris_set_ver = iris[:100] + target_cb = self.widget.controls.target_class_index + + lr = LogisticRegressionLearner()(iris) + self.send_signal(self.widget.Inputs.classifier, lr) + simulate.combobox_activate_index(target_cb, 2) + self.assertEqual(target_cb.currentIndex(), 2) + self.assertEqual(target_cb.count(), 3) + + lr = LogisticRegressionLearner()(iris_set_ver) + self.send_signal(self.widget.Inputs.classifier, lr) + self.assertEqual(target_cb.currentIndex(), 0) + self.assertEqual(target_cb.count(), 2) + + nb = NaiveBayesLearner()(iris) + self.send_signal(self.widget.Inputs.classifier, nb) + simulate.combobox_activate_index(target_cb, 2) + self.assertEqual(target_cb.currentIndex(), 2) + self.assertEqual(target_cb.count(), 3) + + nb = NaiveBayesLearner()(iris_set_ver) + self.send_signal(self.widget.Inputs.classifier, nb) + self.assertEqual(target_cb.currentIndex(), 2) + self.assertEqual(target_cb.count(), 3) + if __name__ == "__main__": unittest.main()