Skip to content

Commit 2c13d0a

Browse files
authored
Merge pull request #5847 from VesnaT/fix_nomogram_class_var
[FIX] Nomogram: Purge class_var values
2 parents 0d9ad11 + 16d475a commit 2c13d0a

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

Orange/widgets/visualize/ownomogram.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -910,8 +910,7 @@ def calculate_log_reg_coefficients(self):
910910
if not isinstance(self.classifier, LogisticRegressionClassifier):
911911
return
912912

913-
self.domain = self.reconstruct_domain(self.classifier.original_domain,
914-
self.domain)
913+
self.domain = self.reconstruct_domain(self.classifier, self.domain)
915914
self.data = self.classifier.original_data.transform(self.domain)
916915
attrs, ranges, start = self.domain.attributes, [], 0
917916
for attr in attrs:
@@ -1273,8 +1272,9 @@ def send_report(self):
12731272
self.report_plot()
12741273

12751274
@staticmethod
1276-
def reconstruct_domain(original, preprocessed):
1275+
def reconstruct_domain(classifier: Model, preprocessed: Domain) -> Domain:
12771276
# abuse dict to make "in" comparisons faster
1277+
original = classifier.original_domain
12781278
attrs = OrderedDict()
12791279
for attr in preprocessed.attributes:
12801280
cv = attr._compute_value.variable._compute_value
@@ -1284,7 +1284,13 @@ def reconstruct_domain(original, preprocessed):
12841284
continue
12851285
attrs[var] = None # we only need keys
12861286
attrs = list(attrs.keys())
1287-
return Domain(attrs, original.class_var, original.metas)
1287+
1288+
orig_clv = original.class_var
1289+
orig_data = classifier.original_data
1290+
values = (orig_clv.values[int(i)] for i in
1291+
np.unique(orig_data.get_column_view(orig_clv)[0]))
1292+
class_var = DiscreteVariable(original.class_var.name, values)
1293+
return Domain(attrs, class_var, original.metas)
12881294

12891295
@staticmethod
12901296
def get_ruler_values(start, stop, max_width, round_to_nearest=True):

Orange/widgets/visualize/tests/test_ownomogram.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from Orange.preprocess import Scale, Continuize
1515
from Orange.tests import test_filename
1616
from Orange.widgets.tests.base import WidgetTest
17+
from Orange.widgets.tests.utils import simulate
1718
from Orange.widgets.visualize.ownomogram import (
1819
OWNomogram, DiscreteFeatureItem, ContinuousFeatureItem, ProbabilitiesDotItem,
1920
MovableToolTip
@@ -293,25 +294,52 @@ def test_dots_stop_flashing(self):
293294
def test_reconstruct_domain(self):
294295
data = Table("heart_disease")
295296
cls = LogisticRegressionLearner()(data)
296-
domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain)
297+
domain = OWNomogram.reconstruct_domain(cls, cls.domain)
297298
transformed_data = cls.original_data.transform(domain)
298299
self.assertEqual(transformed_data.X.shape, data.X.shape)
299300
self.assertFalse(np.isnan(transformed_data.X[0]).any())
300301

301302
scaled_data = Scale()(data)
302303
cls = LogisticRegressionLearner()(scaled_data)
303-
domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain)
304+
domain = OWNomogram.reconstruct_domain(cls, cls.domain)
304305
transformed_data = cls.original_data.transform(domain)
305306
self.assertEqual(transformed_data.X.shape, scaled_data.X.shape)
306307
self.assertFalse(np.isnan(transformed_data.X[0]).any())
307308

308309
disc_data = Continuize()(data)
309310
cls = LogisticRegressionLearner()(disc_data)
310-
domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain)
311+
domain = OWNomogram.reconstruct_domain(cls, cls.domain)
311312
transformed_data = cls.original_data.transform(domain)
312313
self.assertEqual(transformed_data.X.shape, disc_data.X.shape)
313314
self.assertFalse(np.isnan(transformed_data.X[0]).any())
314315

316+
def test_missing_class_value(self):
317+
iris = Table("iris")
318+
iris_set_ver = iris[:100]
319+
target_cb = self.widget.controls.target_class_index
320+
321+
lr = LogisticRegressionLearner()(iris)
322+
self.send_signal(self.widget.Inputs.classifier, lr)
323+
simulate.combobox_activate_index(target_cb, 2)
324+
self.assertEqual(target_cb.currentIndex(), 2)
325+
self.assertEqual(target_cb.count(), 3)
326+
327+
lr = LogisticRegressionLearner()(iris_set_ver)
328+
self.send_signal(self.widget.Inputs.classifier, lr)
329+
self.assertEqual(target_cb.currentIndex(), 0)
330+
self.assertEqual(target_cb.count(), 2)
331+
332+
nb = NaiveBayesLearner()(iris)
333+
self.send_signal(self.widget.Inputs.classifier, nb)
334+
simulate.combobox_activate_index(target_cb, 2)
335+
self.assertEqual(target_cb.currentIndex(), 2)
336+
self.assertEqual(target_cb.count(), 3)
337+
338+
nb = NaiveBayesLearner()(iris_set_ver)
339+
self.send_signal(self.widget.Inputs.classifier, nb)
340+
self.assertEqual(target_cb.currentIndex(), 2)
341+
self.assertEqual(target_cb.count(), 3)
342+
315343

316344
if __name__ == "__main__":
317345
unittest.main()

0 commit comments

Comments
 (0)