Skip to content

Commit f7b0268

Browse files
committed
OWPCA: Output ApplyDomain preprocessor
1 parent 62bd4b7 commit f7b0268

File tree

4 files changed

+52
-1
lines changed

4 files changed

+52
-1
lines changed

Orange/preprocess/preprocess.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,21 @@ def transform(var):
488488
return data.transform(domain)
489489

490490

491+
class ApplyDomain(Preprocess):
    """
    Preprocessor that converts data into a fixed, previously built domain.

    Parameters
    ----------
    domain
        Target domain handed to ``data.transform`` on every call.
    name
        Value returned by ``str()`` for this preprocessor.
    max_components : optional
        When not None, the transformed result is sliced with
        ``[:, :max_components]`` before being returned.
    """
    def __init__(self, domain, name, max_components=None):
        self._domain = domain
        self._name = name
        self._max_components = max_components

    def __call__(self, data):
        projected = data.transform(self._domain)
        limit = self._max_components
        if limit is None:
            # No cap requested: hand back the full transform.
            return projected
        return projected[:, :limit]

    def __str__(self):
        return self._name
504+
505+
491506
class PreprocessorList(Preprocess):
492507
"""
493508
Store a list of preprocessors and on call apply them to the dataset.

Orange/widgets/data/tests/test_owtransform.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
# pylint: disable=missing-docstring
33
from Orange.data import Table
44
from Orange.preprocess import Discretize
5+
from Orange.preprocess.preprocess import Preprocess
56
from Orange.widgets.data.owtransform import OWTransform
67
from Orange.widgets.tests.base import WidgetTest
8+
from Orange.widgets.unsupervised.owpca import OWPCA
79

810

911
class TestOWTransform(WidgetTest):
@@ -61,6 +63,19 @@ def test_output(self):
6163
self.widget.preprocessor_label.text())
6264
self.assertEqual("", self.widget.output_label.text())
6365

66+
def test_input_pca_preprocessor(self):
    """Check that the preprocessor output by OWPCA can drive this widget."""
    owpca = self.create_widget(OWPCA)
    self.send_signal(owpca.Inputs.data, self.data, widget=owpca)
    owpca.components_spin.setValue(2)
    pp = self.get_output(owpca.Outputs.preprocessor, widget=owpca)
    # assertIsNotNone(pp, Preprocess) would only use Preprocess as the
    # failure message; assertIsInstance actually verifies the type
    # (and still fails when pp is None).
    self.assertIsInstance(pp, Preprocess)

    self.send_signal(self.widget.Inputs.data, self.data)
    self.send_signal(self.widget.Inputs.preprocessor, pp)
    output = self.get_output(self.widget.Outputs.transformed_data)
    self.assertIsInstance(output, Table)
    # Two components were requested on the PCA widget above.
    self.assertEqual(output.X.shape, (len(self.data), 2))
78+
6479
def test_send_report(self):
6580
self.send_signal(self.widget.Inputs.data, self.data)
6681
self.widget.report_button.click()

Orange/widgets/unsupervised/owpca.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from Orange.data import Table, Domain, StringVariable, ContinuousVariable
1111
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
1212
from Orange.preprocess import Normalize
13+
from Orange.preprocess.preprocess import Preprocess, ApplyDomain
1314
from Orange.projection import PCA, TruncatedSVD
1415
from Orange.widgets import widget, gui, settings
1516
from Orange.widgets.widget import Input, Output
@@ -44,6 +45,7 @@ class Outputs:
4445
transformed_data = Output("Transformed data", Table)
4546
components = Output("Components", Table)
4647
pca = Output("PCA", PCA, dynamic=False)
48+
preprocessor = Output("Preprocessor", Preprocess)
4749

4850
settingsHandler = settings.DomainContextHandler()
4951

@@ -290,6 +292,7 @@ def clear_outputs(self):
290292
self.Outputs.transformed_data.send(None)
291293
self.Outputs.components.send(None)
292294
self.Outputs.pca.send(self._pca_projector)
295+
self.Outputs.preprocessor.send(None)
293296

294297
def get_model(self):
295298
if self.rpca is None:
@@ -455,7 +458,7 @@ def _update_axis(self):
455458
axis.setTicks([[(i, str(i+1)) for i in range(0, p, d)]])
456459

457460
def commit(self):
458-
transformed = components = None
461+
transformed = components = pp = None
459462
if self._pca is not None:
460463
if self._transformed is None:
461464
# Compute the full transform (MAX_COMPONENTS components) only once.
@@ -479,10 +482,14 @@ def commit(self):
479482
metas=metas)
480483
components.name = 'components'
481484

485+
domain = self._pca_projector(self.data).domain
486+
pp = ApplyDomain(domain, "PCA", self.ncomponents)
487+
482488
self._pca_projector.component = self.ncomponents
483489
self.Outputs.transformed_data.send(transformed)
484490
self.Outputs.components.send(components)
485491
self.Outputs.pca.send(self._pca_projector)
492+
self.Outputs.preprocessor.send(pp)
486493

487494
def send_report(self):
488495
if self.data is None:

Orange/widgets/unsupervised/tests/test_owpca.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import scipy.sparse as sp
55

66
from Orange.data import Table, Domain, ContinuousVariable, TimeVariable
7+
from Orange.preprocess.preprocess import Preprocess
78
from Orange.widgets.tests.base import WidgetTest
89
from Orange.widgets.unsupervised.owpca import OWPCA, DECOMPOSITIONS
910

@@ -131,3 +132,16 @@ def test_do_not_mask_features(self):
131132
self.widget.set_data(data)
132133
ndata = Table("iris.tab")
133134
self.assertEqual(data.domain[0], ndata.domain[0])
135+
136+
def test_output_preprocessor(self):
    """The preprocessor output must reproduce the widget's own transform."""
    iris = Table("iris")
    self.send_signal(self.widget.Inputs.data, iris)

    preprocessor = self.get_output(self.widget.Outputs.preprocessor)
    self.assertIsInstance(preprocessor, Preprocess)

    # Apply the preprocessor to every tenth row of the input.
    subset_result = preprocessor(iris[::10])
    self.assertIsInstance(subset_result, Table)
    self.assertEqual(subset_result.X.shape, (15, 2))

    # Values must match the same rows of the widget's transformed output.
    widget_output = self.get_output(self.widget.Outputs.transformed_data)
    np.testing.assert_array_equal(subset_result.X, widget_output.X[::10])

    # Attribute names must agree as well.
    subset_names = [attr.name for attr in subset_result.domain.attributes]
    output_names = [attr.name for attr in widget_output.domain.attributes]
    self.assertEqual(subset_names, output_names)

0 commit comments

Comments
 (0)