Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions Orange/widgets/visualize/ownomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,8 +910,7 @@ def calculate_log_reg_coefficients(self):
if not isinstance(self.classifier, LogisticRegressionClassifier):
return

self.domain = self.reconstruct_domain(self.classifier.original_domain,
self.domain)
self.domain = self.reconstruct_domain(self.classifier, self.domain)
self.data = self.classifier.original_data.transform(self.domain)
attrs, ranges, start = self.domain.attributes, [], 0
for attr in attrs:
Expand Down Expand Up @@ -1273,8 +1272,9 @@ def send_report(self):
self.report_plot()

@staticmethod
def reconstruct_domain(original, preprocessed):
def reconstruct_domain(classifier: Model, preprocessed: Domain) -> Domain:
# abuse dict to make "in" comparisons faster
original = classifier.original_domain
attrs = OrderedDict()
for attr in preprocessed.attributes:
cv = attr._compute_value.variable._compute_value
Expand All @@ -1284,7 +1284,13 @@ def reconstruct_domain(original, preprocessed):
continue
attrs[var] = None # we only need keys
attrs = list(attrs.keys())
return Domain(attrs, original.class_var, original.metas)

orig_clv = original.class_var
orig_data = classifier.original_data
values = (orig_clv.values[int(i)] for i in
np.unique(orig_data.get_column_view(orig_clv)[0]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.unique is $n \log n$. I think it would be better to have something like

mask = np.zeros(len(orig_clv.values), dtype=bool)
column = orig_data.get_column_view(orig_clv)[0]
mask[column[np.isfinite(column)].astype(int)] = True
values = tuple(np.array(orig_clv.values)[mask])

(I haven't tested this code, it's just an idea.)

Copy link
Contributor Author

@VesnaT VesnaT Feb 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the unique. For me, it is easier to read the code.

I'd leave it as is, unless you think it's too time consuming.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like what Janez did, but perhaps we could pack it into a function. Looking for values of discrete variables that are used in some column seems like something common.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If @VesnaT finds the original better, we can keep it.

However, @markotoplak's idea to have a function make sense. Where to put it, how to name it? Hm, let's put it to Orange.prepocess.remove and name it remove_unused_values. Oh, no, wait, it's already there.

It is understandable that @VesnaT and I did not know about this function. Oh, no, wait, @VesnaT wrote it and I did some work on it two years ago.

class_var = DiscreteVariable(original.class_var.name, values)
return Domain(attrs, class_var, original.metas)

@staticmethod
def get_ruler_values(start, stop, max_width, round_to_nearest=True):
Expand Down
34 changes: 31 additions & 3 deletions Orange/widgets/visualize/tests/test_ownomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from Orange.preprocess import Scale, Continuize
from Orange.tests import test_filename
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import simulate
from Orange.widgets.visualize.ownomogram import (
OWNomogram, DiscreteFeatureItem, ContinuousFeatureItem, ProbabilitiesDotItem,
MovableToolTip
Expand Down Expand Up @@ -293,25 +294,52 @@ def test_dots_stop_flashing(self):
def test_reconstruct_domain(self):
data = Table("heart_disease")
cls = LogisticRegressionLearner()(data)
domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain)
domain = OWNomogram.reconstruct_domain(cls, cls.domain)
transformed_data = cls.original_data.transform(domain)
self.assertEqual(transformed_data.X.shape, data.X.shape)
self.assertFalse(np.isnan(transformed_data.X[0]).any())

scaled_data = Scale()(data)
cls = LogisticRegressionLearner()(scaled_data)
domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain)
domain = OWNomogram.reconstruct_domain(cls, cls.domain)
transformed_data = cls.original_data.transform(domain)
self.assertEqual(transformed_data.X.shape, scaled_data.X.shape)
self.assertFalse(np.isnan(transformed_data.X[0]).any())

disc_data = Continuize()(data)
cls = LogisticRegressionLearner()(disc_data)
domain = OWNomogram.reconstruct_domain(cls.original_domain, cls.domain)
domain = OWNomogram.reconstruct_domain(cls, cls.domain)
transformed_data = cls.original_data.transform(domain)
self.assertEqual(transformed_data.X.shape, disc_data.X.shape)
self.assertFalse(np.isnan(transformed_data.X[0]).any())

def test_missing_class_value(self):
iris = Table("iris")
iris_set_ver = iris[:100]
target_cb = self.widget.controls.target_class_index

lr = LogisticRegressionLearner()(iris)
self.send_signal(self.widget.Inputs.classifier, lr)
simulate.combobox_activate_index(target_cb, 2)
self.assertEqual(target_cb.currentIndex(), 2)
self.assertEqual(target_cb.count(), 3)

lr = LogisticRegressionLearner()(iris_set_ver)
self.send_signal(self.widget.Inputs.classifier, lr)
self.assertEqual(target_cb.currentIndex(), 0)
self.assertEqual(target_cb.count(), 2)

nb = NaiveBayesLearner()(iris)
self.send_signal(self.widget.Inputs.classifier, nb)
simulate.combobox_activate_index(target_cb, 2)
self.assertEqual(target_cb.currentIndex(), 2)
self.assertEqual(target_cb.count(), 3)

nb = NaiveBayesLearner()(iris_set_ver)
self.send_signal(self.widget.Inputs.classifier, nb)
self.assertEqual(target_cb.currentIndex(), 2)
self.assertEqual(target_cb.count(), 3)


if __name__ == "__main__":
unittest.main()