Skip to content

Commit 1725baa

Browse files
authored
Merge pull request #6536 from PrimozGodec/fix-pca
[FIX] PCA - Output instance of table subclass when instance of table subclass on input
2 parents 80f3cc9 + 981317c commit 1725baa

File tree

2 files changed

+37
-27
lines changed

2 files changed

+37
-27
lines changed

Orange/widgets/unsupervised/owpca.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from AnyQt.QtCore import Qt
66

77
from orangewidget.report import bool_str
8+
from orangewidget.settings import Setting
89

910
from Orange.data import Table, Domain, StringVariable, ContinuousVariable
1011
from Orange.data.util import get_unique_names
1112
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
1213
from Orange.preprocess import preprocess
1314
from Orange.projection import PCA
14-
from Orange.widgets import widget, gui, settings
15+
from Orange.widgets import widget, gui
16+
from Orange.widgets.utils.annotated_data import add_columns
1517
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin
1618
from Orange.widgets.utils.slidergraph import SliderGraph
1719
from Orange.widgets.utils.widgetpreview import WidgetPreview
@@ -38,12 +40,12 @@ class Outputs:
3840
components = Output("Components", Table)
3941
pca = Output("PCA", PCA, dynamic=False)
4042

41-
ncomponents = settings.Setting(2)
42-
variance_covered = settings.Setting(100)
43-
auto_commit = settings.Setting(True)
44-
normalize = settings.Setting(True)
45-
maxp = settings.Setting(20)
46-
axis_labels = settings.Setting(10)
43+
ncomponents = Setting(2)
44+
variance_covered = Setting(100)
45+
auto_commit = Setting(True)
46+
normalize = Setting(True)
47+
maxp = Setting(20)
48+
axis_labels = Setting(10)
4749

4850
graph_name = "plot.plotItem" # QGraphicsView (pg.PlotWidget -> SliderGraph)
4951

@@ -222,8 +224,7 @@ def _setup_plot(self):
222224
self._update_axis()
223225

224226
def _on_cut_changed(self, components):
225-
if components == self.ncomponents \
226-
or self.ncomponents == 0:
227+
if self.ncomponents in (components, 0):
227228
return
228229

229230
self.ncomponents = components
@@ -333,9 +334,9 @@ def commit(self):
333334
proposed = [a.name for a in self._pca.orig_domain.attributes]
334335
meta_name = get_unique_names(proposed, 'components')
335336
meta_vars = [StringVariable(name=meta_name)]
336-
metas = numpy.array([['PC{}'.format(i + 1)
337-
for i in range(self.ncomponents)]],
338-
dtype=object).T
337+
metas = numpy.array(
338+
[[f"PC{i + 1}"for i in range(self.ncomponents)]], dtype=object
339+
).T
339340
if self._variance_ratio is not None:
340341
variance_name = get_unique_names(proposed, "variance")
341342
meta_vars.append(ContinuousVariable(variance_name))
@@ -351,14 +352,8 @@ def commit(self):
351352
metas=metas)
352353
components.name = 'components'
353354

354-
data_dom = Domain(
355-
self.data.domain.attributes,
356-
self.data.domain.class_vars,
357-
self.data.domain.metas + domain.attributes)
358-
data = Table.from_numpy(
359-
data_dom, self.data.X, self.data.Y,
360-
numpy.hstack((self.data.metas, transformed.X)),
361-
ids=self.data.ids)
355+
data_dom = add_columns(self.data.domain, metas=domain.attributes)
356+
data = self.data.transform(data_dom)
362357

363358
self.Outputs.transformed_data.send(transformed)
364359
self.Outputs.components.send(components)
@@ -371,7 +366,7 @@ def send_report(self):
371366
self.report_items((
372367
("Normalize data", bool_str(self.normalize)),
373368
("Selected components", self.ncomponents),
374-
("Explained variance", "{:.3f} %".format(self.variance_covered))
369+
("Explained variance", f"{self.variance_covered:.3f} %")
375370
))
376371
self.report_plot()
377372

Orange/widgets/unsupervised/tests/test_owpca.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
from unittest.mock import patch, Mock
55

66
import numpy as np
7+
from sklearn.utils import check_random_state
8+
from sklearn.utils.extmath import svd_flip
79

810
from Orange.data import Table, Domain, ContinuousVariable, TimeVariable
911
from Orange.preprocess import preprocess
1012
from Orange.widgets.tests.base import WidgetTest
1113
from Orange.widgets.tests.utils import table_dense_sparse, possible_duplicate_table
1214
from Orange.widgets.unsupervised.owpca import OWPCA
1315
from Orange.tests import test_filename
14-
from sklearn.utils import check_random_state
15-
from sklearn.utils.extmath import svd_flip
1616

1717

1818
class TestOWPCA(WidgetTest):
@@ -63,19 +63,19 @@ def test_limit_components(self):
6363
self.widget._setup_plot() # pylint: disable=protected-access
6464

6565
def test_migrate_settings_limits_components(self):
66-
settings = dict(ncomponents=10)
66+
settings = {"ncomponents": 10}
6767
OWPCA.migrate_settings(settings, 0)
6868
self.assertEqual(settings['ncomponents'], 10)
69-
settings = dict(ncomponents=101)
69+
settings = {"ncomponents": 101}
7070
OWPCA.migrate_settings(settings, 0)
7171
self.assertEqual(settings['ncomponents'], 100)
7272

7373
def test_migrate_settings_changes_variance_covered_to_int(self):
74-
settings = dict(variance_covered=17.5)
74+
settings = {"variance_covered": 17.5}
7575
OWPCA.migrate_settings(settings, 0)
7676
self.assertEqual(settings["variance_covered"], 17)
7777

78-
settings = dict(variance_covered=float('nan'))
78+
settings = {"variance_covered": float('nan')}
7979
OWPCA.migrate_settings(settings, 0)
8080
self.assertEqual(settings["variance_covered"], 100)
8181

@@ -277,6 +277,21 @@ def test_output_data(self):
277277
output = self.get_output(widget.Outputs.data)
278278
self.assertIsNone(output)
279279

280+
def test_table_subclass(self):
281+
"""
282+
When the input table is an instance of a Table subclass (e.g. Corpus), the
283+
resulting tables should also be instances of that subclass.
284+
"""
285+
class TableSub(Table): # pylint: disable=abstract-method
286+
pass
287+
288+
table_subclass = TableSub(self.iris)
289+
self.send_signal(self.widget.Inputs.data, table_subclass)
290+
data_out = self.get_output(self.widget.Outputs.data)
291+
trans_data_out = self.get_output(self.widget.Outputs.transformed_data)
292+
self.assertIsInstance(data_out, TableSub)
293+
self.assertIsInstance(trans_data_out, TableSub)
294+
280295

281296
if __name__ == "__main__":
282297
unittest.main()

0 commit comments

Comments
 (0)