diff --git a/Orange/widgets/unsupervised/owtsne.py b/Orange/widgets/unsupervised/owtsne.py index 32171bbcb46..8e7d4f455d3 100644 --- a/Orange/widgets/unsupervised/owtsne.py +++ b/Orange/widgets/unsupervised/owtsne.py @@ -6,8 +6,8 @@ from AnyQt.QtWidgets import QFormLayout from Orange.data import Table, Domain -from Orange.preprocess.preprocess import Preprocess, ApplyDomain -from Orange.projection import PCA, TSNE, TruncatedSVD +from Orange.preprocess import preprocess +from Orange.projection import PCA, TSNE from Orange.projection.manifold import TSNEModel from Orange.widgets import gui from Orange.widgets.settings import Setting, SettingProvider @@ -76,6 +76,7 @@ class OWtSNE(OWDataProjectionWidget): multiscale = Setting(True) exaggeration = Setting(1) pca_components = Setting(20) + normalize = Setting(True) GRAPH_CLASS = OWtSNEGraph graph = SettingProvider(OWtSNEGraph) @@ -85,7 +86,7 @@ class OWtSNE(OWDataProjectionWidget): Running, Finished, Waiting, Paused = 1, 2, 3, 4 class Outputs(OWDataProjectionWidget.Outputs): - preprocessor = Output("Preprocessor", Preprocess) + preprocessor = Output("Preprocessor", preprocess.Preprocess) class Error(OWDataProjectionWidget.Error): not_enough_rows = Msg("Input data needs at least 2 rows") @@ -143,15 +144,25 @@ def _add_controls_start_box(self): sbp = gui.hBox(self.controlArea, False, addToLayout=False) gui.hSlider( sbp, self, "pca_components", minValue=2, maxValue=50, step=1, - callback=self._params_changed + callback=self._invalidate_pca_projection ) form.addRow("PCA components:", sbp) + self.normalize_cbx = gui.checkBox( + box, self, "normalize", "Normalize data", + callback=self._invalidate_pca_projection, + ) + form.addRow(self.normalize_cbx) + box.layout().addLayout(form) gui.separator(box, 10) self.runbutton = gui.button(box, self, "Run", callback=self._toggle_run) + def _invalidate_pca_projection(self): + self.pca_data = None + self._params_changed() + def _params_changed(self): self.__state = OWtSNE.Finished self.__set_update_loop(None) @@ -215,12 +226,32 @@ def stop(self): def resume(self): self.__set_update_loop(self.tsne_iterator) + def set_data(self, data: Table): + super().set_data(data) + + if data is not None: + # PCA doesn't support normalization on sparse data, as this would + # require centering and normalizing the matrix + self.normalize_cbx.setDisabled(data.is_sparse()) + if data.is_sparse(): + self.normalize = False + self.normalize_cbx.setToolTip( + "Data normalization is not supported on sparse matrices." + ) + else: + self.normalize_cbx.setToolTip("") + def pca_preprocessing(self): - if self.pca_data is not None and \ - self.pca_data.X.shape[1] == self.pca_components: + """Perform PCA preprocessing before passing off the data to t-SNE.""" + if self.pca_data is not None: return - cls = TruncatedSVD if self.data.is_sparse() else PCA - projector = cls(n_components=self.pca_components, random_state=0) + + projector = PCA(n_components=self.pca_components, random_state=0) + # If the normalization box is ticked, we'll add the `Normalize` + # preprocessor to PCA + if self.normalize: + projector.preprocessors += (preprocess.Normalize(),) + model = projector(self.data) self.pca_data = model(self.data) @@ -343,7 +374,7 @@ def _get_projection_data(self): def send_preprocessor(self): prep = None if self.data is not None and self.projection is not None: - prep = ApplyDomain(self.projection.domain, self.projection.name) + prep = preprocess.ApplyDomain(self.projection.domain, self.projection.name) self.Outputs.preprocessor.send(prep) def clear(self): diff --git a/Orange/widgets/unsupervised/tests/test_owtsne.py b/Orange/widgets/unsupervised/tests/test_owtsne.py index 1cfede067cf..d79c2ee7b35 100644 --- a/Orange/widgets/unsupervised/tests/test_owtsne.py +++ b/Orange/widgets/unsupervised/tests/test_owtsne.py @@ -1,10 +1,9 @@ import unittest +from unittest.mock import patch import numpy as np -from AnyQt.QtTest import QSignalSpy - from Orange.data import DiscreteVariable, ContinuousVariable, Domain, Table -from Orange.preprocess import Preprocess +from Orange.preprocess import Preprocess, Normalize from Orange.projection.manifold import TSNE from Orange.widgets.tests.base import ( WidgetTest, WidgetOutputsTestMixin, ProjectionWidgetTestMixin @@ -50,9 +49,9 @@ def optimize(*_, **__): self.empty_domain = Domain([], class_vars=self.class_var) def tearDown(self): - self.reset_tsne() + self.restore_mocked_functions() - def reset_tsne(self): + def restore_mocked_functions(self): owtsne.TSNE.fit = self._fit owtsne.TSNEModel.transform = self._transform owtsne.TSNEModel.optimize = self._optimize @@ -113,21 +112,26 @@ def test_attr_models(self): self.assertIn(var, controls.attr_shape.model()) def test_output_preprocessor(self): - self.reset_tsne() + # To test the validity of the preprocessor, we'll have to actually + # compute the projections + self.restore_mocked_functions() + self.send_signal(self.widget.Inputs.data, self.data) - if self.widget.isBlocking(): - spy = QSignalSpy(self.widget.blockingStateChanged) - self.assertTrue(spy.wait(20000)) + self.wait_until_stop_blocking(wait=20000) + output_data = self.get_output(self.widget.Outputs.annotated_data) + + # We send the same data to the widget, we expect the point locations to + # be fairly close to their original ones pp = self.get_output(self.widget.Outputs.preprocessor) self.assertIsInstance(pp, Preprocess) - transformed = pp(self.data) - self.assertIsInstance(transformed, Table) - self.assertEqual(transformed.X.shape, (len(self.data), 2)) - output = self.get_output(self.widget.Outputs.annotated_data) - np.testing.assert_allclose(transformed.X, output.metas[:, :2], - rtol=1, atol=1) - self.assertEqual([a.name for a in transformed.domain.attributes], - [m.name for m in output.domain.metas[:2]]) + + transformed_data = pp(self.data) + self.assertIsInstance(transformed_data, Table) + self.assertEqual(transformed_data.X.shape, (len(self.data), 2)) + np.testing.assert_allclose(transformed_data.X, output_data.metas[:, :2], + rtol=1, atol=3) + self.assertEqual([a.name for a in transformed_data.domain.attributes], + [m.name for m in output_data.domain.metas[:2]]) def test_multiscale_changed(self): self.assertFalse(self.widget.controls.multiscale.isChecked()) @@ -140,6 +144,32 @@ def test_multiscale_changed(self): self.assertTrue(w.controls.multiscale.isChecked()) self.assertFalse(w.perplexity_spin.isEnabled()) + def test_normalize_data(self): + # Normalization should be checked by default + self.assertTrue(self.widget.controls.normalize.isChecked()) + with patch("Orange.preprocess.preprocess.Normalize", wraps=Normalize) as normalize: + self.send_signal(self.widget.Inputs.data, self.data) + self.assertTrue(self.widget.controls.normalize.isEnabled()) + normalize.assert_called_once() + + # Disable checkbox + self.widget.controls.normalize.setChecked(False) + self.assertFalse(self.widget.controls.normalize.isChecked()) + with patch("Orange.preprocess.preprocess.Normalize", wraps=Normalize) as normalize: + self.send_signal(self.widget.Inputs.data, self.data) + self.assertTrue(self.widget.controls.normalize.isEnabled()) + normalize.assert_not_called() + + # Normalization shouldn't work on sparse data + self.widget.controls.normalize.setChecked(True) + self.assertTrue(self.widget.controls.normalize.isChecked()) + + sparse_data = self.data.to_sparse() + with patch("Orange.preprocess.preprocess.Normalize", wraps=Normalize) as normalize: + self.send_signal(self.widget.Inputs.data, sparse_data) + self.assertFalse(self.widget.controls.normalize.isEnabled()) + normalize.assert_not_called() + if __name__ == '__main__': unittest.main()