Skip to content

Commit 509e41d

Browse files
committed
PCA - Output instance of table subclass when instance of table subclass on input
1 parent f531854 commit 509e41d

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

Orange/widgets/unsupervised/owpca.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
1212
from Orange.preprocess import preprocess
1313
from Orange.projection import PCA
14-
from Orange.widgets import widget, gui, settings
14+
from Orange.widgets import widget, gui
15+
from Orange.widgets.utils.annotated_data import add_columns
1516
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin
1617
from Orange.widgets.utils.slidergraph import SliderGraph
1718
from Orange.widgets.utils.widgetpreview import WidgetPreview
@@ -38,12 +39,12 @@ class Outputs:
3839
components = Output("Components", Table)
3940
pca = Output("PCA", PCA, dynamic=False)
4041

41-
ncomponents = settings.Setting(2)
42-
variance_covered = settings.Setting(100)
43-
auto_commit = settings.Setting(True)
44-
normalize = settings.Setting(True)
45-
maxp = settings.Setting(20)
46-
axis_labels = settings.Setting(10)
42+
ncomponents = Setting(2)
43+
variance_covered = Setting(100)
44+
auto_commit = Setting(True)
45+
normalize = Setting(True)
46+
maxp = Setting(20)
47+
axis_labels = Setting(10)
4748

4849
graph_name = "plot.plotItem" # QGraphicsView (pg.PlotWidget -> SliderGraph)
4950

@@ -222,8 +223,7 @@ def _setup_plot(self):
222223
self._update_axis()
223224

224225
def _on_cut_changed(self, components):
225-
if components == self.ncomponents \
226-
or self.ncomponents == 0:
226+
if self.ncomponents in (components, 0):
227227
return
228228

229229
self.ncomponents = components
@@ -333,9 +333,9 @@ def commit(self):
333333
proposed = [a.name for a in self._pca.orig_domain.attributes]
334334
meta_name = get_unique_names(proposed, 'components')
335335
meta_vars = [StringVariable(name=meta_name)]
336-
metas = numpy.array([['PC{}'.format(i + 1)
337-
for i in range(self.ncomponents)]],
338-
dtype=object).T
336+
metas = numpy.array(
337+
[[f"PC{i + 1}"for i in range(self.ncomponents)]], dtype=object
338+
).T
339339
if self._variance_ratio is not None:
340340
variance_name = get_unique_names(proposed, "variance")
341341
meta_vars.append(ContinuousVariable(variance_name))
@@ -351,14 +351,10 @@ def commit(self):
351351
metas=metas)
352352
components.name = 'components'
353353

354-
data_dom = Domain(
355-
self.data.domain.attributes,
356-
self.data.domain.class_vars,
357-
self.data.domain.metas + domain.attributes)
358-
data = Table.from_numpy(
359-
data_dom, self.data.X, self.data.Y,
360-
numpy.hstack((self.data.metas, transformed.X)),
361-
ids=self.data.ids)
354+
data_dom = add_columns(self.data.domain, metas=domain.attributes)
355+
data = self.data.transform(data_dom)
356+
with data.unlocked(data.metas):
357+
data[:, domain.attributes] = transformed.X
362358

363359
self.Outputs.transformed_data.send(transformed)
364360
self.Outputs.components.send(components)
@@ -371,7 +367,7 @@ def send_report(self):
371367
self.report_items((
372368
("Normalize data", bool_str(self.normalize)),
373369
("Selected components", self.ncomponents),
374-
("Explained variance", "{:.3f} %".format(self.variance_covered))
370+
("Explained variance", f"{self.variance_covered:.3f} %")
375371
))
376372
self.report_plot()
377373

Orange/widgets/unsupervised/tests/test_owpca.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,21 @@ def test_output_data(self):
277277
output = self.get_output(widget.Outputs.data)
278278
self.assertIsNone(output)
279279

280+
def test_table_subclass(self):
281+
"""
282+
When input table is instance of Table's subclass (e.g. Corpus) resulting
283+
tables should also be an instance subclasses
284+
"""
285+
class TableSub(Table):
286+
pass
287+
288+
table_subclass = TableSub(self.iris)
289+
self.send_signal(self.widget.Inputs.data, table_subclass)
290+
data_out = self.get_output(self.widget.Outputs.data)
291+
trans_data_out = self.get_output(self.widget.Outputs.transformed_data)
292+
self.assertIsInstance(data_out, TableSub)
293+
self.assertIsInstance(trans_data_out, TableSub)
294+
280295

281296
if __name__ == "__main__":
282297
unittest.main()

0 commit comments

Comments
 (0)