Skip to content

Commit 5221d78

Browse files
OwTSNE: Add Normalize data checkbox
1 parent 3fae01f commit 5221d78

File tree

2 files changed: 48 additions and 19 deletions

2 files changed

+48
-19
lines changed

Orange/widgets/unsupervised/owtsne.py

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from AnyQt.QtWidgets import QFormLayout
77

88
from Orange.data import Table, Domain
9-
from Orange.preprocess.preprocess import Preprocess, ApplyDomain
9+
from Orange.preprocess.preprocess import Preprocess, ApplyDomain, Normalize
1010
from Orange.projection import PCA, TSNE
1111
from Orange.projection.manifold import TSNEModel
1212
from Orange.widgets import gui
@@ -76,6 +76,7 @@ class OWtSNE(OWDataProjectionWidget):
7676
multiscale = Setting(True)
7777
exaggeration = Setting(1)
7878
pca_components = Setting(20)
79+
normalize = Setting(True)
7980

8081
GRAPH_CLASS = OWtSNEGraph
8182
graph = SettingProvider(OWtSNEGraph)
@@ -143,15 +144,25 @@ def _add_controls_start_box(self):
143144
sbp = gui.hBox(self.controlArea, False, addToLayout=False)
144145
gui.hSlider(
145146
sbp, self, "pca_components", minValue=2, maxValue=50, step=1,
146-
callback=self._params_changed
147+
callback=self._invalidate_pca_projection
147148
)
148149
form.addRow("PCA components:", sbp)
149150

151+
self.normalize_cbx = gui.checkBox(
152+
box, self, "normalize", "Normalize data",
153+
callback=self._invalidate_pca_projection,
154+
)
155+
form.addRow(self.normalize_cbx)
156+
150157
box.layout().addLayout(form)
151158

152159
gui.separator(box, 10)
153160
self.runbutton = gui.button(box, self, "Run", callback=self._toggle_run)
154161

162+
def _invalidate_pca_projection(self):
163+
self.pca_data = None
164+
self._params_changed()
165+
155166
def _params_changed(self):
156167
self.__state = OWtSNE.Finished
157168
self.__set_update_loop(None)
@@ -215,10 +226,25 @@ def stop(self):
215226
def resume(self):
216227
self.__set_update_loop(self.tsne_iterator)
217228

229+
def set_data(self, data: Table):
230+
super().set_data(data)
231+
232+
if data is not None:
233+
# PCA doesn't support normalization on sparse data, as this would
234+
# require centering and normalizing the matrix
235+
self.normalize_cbx.setDisabled(data.is_sparse())
236+
218237
def pca_preprocessing(self):
219-
if self.pca_data is not None and self.pca_data.X.shape[1] == self.pca_components:
238+
"""Perform PCA preprocessing before passing off the data to t-SNE."""
239+
if self.pca_data is not None:
220240
return
241+
221242
projector = PCA(n_components=self.pca_components, random_state=0)
243+
# If the normalization box is ticked, we'll add the `Normalize`
244+
# preprocessor to PCA
245+
if self.normalize:
246+
projector.preprocessors += (Normalize(),)
247+
222248
model = projector(self.data)
223249
self.pca_data = model(self.data)
224250

Orange/widgets/unsupervised/tests/test_owtsne.py

Lines changed: 19 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import unittest
22
import numpy as np
33

4-
from AnyQt.QtTest import QSignalSpy
5-
64
from Orange.data import DiscreteVariable, ContinuousVariable, Domain, Table
75
from Orange.preprocess import Preprocess
86
from Orange.projection.manifold import TSNE
@@ -50,9 +48,9 @@ def optimize(*_, **__):
5048
self.empty_domain = Domain([], class_vars=self.class_var)
5149

5250
def tearDown(self):
53-
self.reset_tsne()
51+
self.restore_mocked_functions()
5452

55-
def reset_tsne(self):
53+
def restore_mocked_functions(self):
5654
owtsne.TSNE.fit = self._fit
5755
owtsne.TSNEModel.transform = self._transform
5856
owtsne.TSNEModel.optimize = self._optimize
@@ -113,21 +111,26 @@ def test_attr_models(self):
113111
self.assertIn(var, controls.attr_shape.model())
114112

115113
def test_output_preprocessor(self):
116-
self.reset_tsne()
114+
# To test the validity of the preprocessor, we'll have to actually
115+
# compute the projections
116+
self.restore_mocked_functions()
117+
117118
self.send_signal(self.widget.Inputs.data, self.data)
118-
if self.widget.isBlocking():
119-
spy = QSignalSpy(self.widget.blockingStateChanged)
120-
self.assertTrue(spy.wait(20000))
119+
self.wait_until_stop_blocking(wait=20000)
120+
output_data = self.get_output(self.widget.Outputs.annotated_data)
121+
122+
# We send the same data to the widget, so we expect the point locations to
123+
# be fairly close to their original ones
121124
pp = self.get_output(self.widget.Outputs.preprocessor)
122125
self.assertIsInstance(pp, Preprocess)
123-
transformed = pp(self.data)
124-
self.assertIsInstance(transformed, Table)
125-
self.assertEqual(transformed.X.shape, (len(self.data), 2))
126-
output = self.get_output(self.widget.Outputs.annotated_data)
127-
np.testing.assert_allclose(transformed.X, output.metas[:, :2],
128-
rtol=1, atol=1)
129-
self.assertEqual([a.name for a in transformed.domain.attributes],
130-
[m.name for m in output.domain.metas[:2]])
126+
127+
transformed_data = pp(self.data)
128+
self.assertIsInstance(transformed_data, Table)
129+
self.assertEqual(transformed_data.X.shape, (len(self.data), 2))
130+
np.testing.assert_allclose(transformed_data.X, output_data.metas[:, :2],
131+
rtol=1, atol=1.5)
132+
self.assertEqual([a.name for a in transformed_data.domain.attributes],
133+
[m.name for m in output_data.domain.metas[:2]])
131134

132135
def test_multiscale_changed(self):
133136
self.assertFalse(self.widget.controls.multiscale.isChecked())

0 commit comments

Comments
 (0)