Skip to content

Commit a1f701a

Browse files
authored
Merge pull request #3475 from VesnaT/tsne4
[ENH] t-SNE: Updates 2.
2 parents 0f189a5 + 8d884fe commit a1f701a

File tree

3 files changed

+55
-22
lines changed

3 files changed

+55
-22
lines changed

Orange/projection/manifold.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,8 @@ def __call__(self, data: Table) -> TSNEModel:
483483
model.name = self.name
484484

485485
return model
486+
487+
@staticmethod
488+
def default_initialization(data, n_components=2, random_state=None):
489+
return fastTSNE.initialization.pca(
490+
data, n_components, random_state=random_state)

Orange/widgets/unsupervised/owtsne.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from AnyQt.QtCore import Qt, QTimer
44
from AnyQt.QtWidgets import QFormLayout
55

6-
import fastTSNE.initialization
7-
86
from Orange.data import Table, Domain
97
from Orange.preprocess.preprocess import Preprocess, ApplyDomain
108
from Orange.projection import PCA, TSNE, TruncatedSVD
@@ -82,7 +80,7 @@ class OWtSNE(OWDataProjectionWidget):
8280
embedding_variables_names = ("t-SNE-x", "t-SNE-y")
8381

8482
#: Runtime state
85-
Running, Finished, Waiting = 1, 2, 3
83+
Running, Finished, Waiting, Paused = 1, 2, 3, 4
8684

8785
class Outputs(OWDataProjectionWidget.Outputs):
8886
preprocessor = Output("Preprocessor", Preprocess)
@@ -100,6 +98,7 @@ def __init__(self):
10098
self.pca_data = None
10199
self.projection = None
102100
self.tsne_runner = None
101+
self.tsne_iterator = None
103102
self.__update_loop = None
104103
# timer for scheduling updates
105104
self.__timer = QTimer(self, singleShot=True, interval=1,
@@ -122,31 +121,42 @@ def _add_controls_start_box(self):
122121
)
123122

124123
self.perplexity_spin = gui.spin(
125-
box, self, "perplexity", 1, 500, step=1, alignment=Qt.AlignRight)
124+
box, self, "perplexity", 1, 500, step=1, alignment=Qt.AlignRight,
125+
callback=self._params_changed
126+
)
126127
form.addRow("Perplexity:", self.perplexity_spin)
128+
self.perplexity_spin.setEnabled(not self.multiscale)
127129
form.addRow(gui.checkBox(
128130
box, self, "multiscale", label="Preserve global structure",
129131
callback=self._multiscale_changed
130132
))
131-
self._multiscale_changed()
132133

133134
sbe = gui.hBox(self.controlArea, False, addToLayout=False)
134135
gui.hSlider(
135-
sbe, self, "exaggeration", minValue=1, maxValue=4, step=1)
136+
sbe, self, "exaggeration", minValue=1, maxValue=4, step=1,
137+
callback=self._params_changed
138+
)
136139
form.addRow("Exaggeration:", sbe)
137140

138141
sbp = gui.hBox(self.controlArea, False, addToLayout=False)
139142
gui.hSlider(
140-
sbp, self, "pca_components", minValue=2, maxValue=50, step=1)
143+
sbp, self, "pca_components", minValue=2, maxValue=50, step=1,
144+
callback=self._params_changed
145+
)
141146
form.addRow("PCA components:", sbp)
142147

143148
box.layout().addLayout(form)
144149

145150
gui.separator(box, 10)
146151
self.runbutton = gui.button(box, self, "Run", callback=self._toggle_run)
147152

153+
def _params_changed(self):
154+
self.__state = OWtSNE.Finished
155+
self.__set_update_loop(None)
156+
148157
def _multiscale_changed(self):
149158
self.perplexity_spin.setEnabled(not self.multiscale)
159+
self._params_changed()
150160

151161
def check_data(self):
152162
def error(err):
@@ -181,6 +191,8 @@ def _toggle_run(self):
181191
if self.__state == OWtSNE.Running:
182192
self.stop()
183193
self.commit()
194+
elif self.__state == OWtSNE.Paused:
195+
self.resume()
184196
else:
185197
self.start()
186198

@@ -191,8 +203,11 @@ def start(self):
191203
self.__start()
192204

193205
def stop(self):
194-
if self.__state == OWtSNE.Running:
195-
self.__set_update_loop(None)
206+
self.__state = OWtSNE.Paused
207+
self.__set_update_loop(None)
208+
209+
def resume(self):
210+
self.__set_update_loop(self.tsne_iterator)
196211

197212
def pca_preprocessing(self):
198213
if self.pca_data is not None and \
@@ -208,7 +223,7 @@ def __start(self):
208223

209224
# We call PCA through fastTSNE because it involves scaling. Instead of
210225
# worrying about this ourselves, we'll let the library worry for us.
211-
initialization = fastTSNE.initialization.pca(
226+
initialization = TSNE.default_initialization(
212227
self.pca_data.X, n_components=2, random_state=0)
213228

214229
# Compute perplexity settings for multiscale
@@ -233,13 +248,14 @@ def __start(self):
233248
)(self.pca_data)
234249

235250
self.tsne_runner = TSNERunner(self.projection, step_size=50)
236-
237-
self.__set_update_loop(self.tsne_runner.run_optimization())
251+
self.tsne_iterator = self.tsne_runner.run_optimization()
252+
self.__set_update_loop(self.tsne_iterator)
238253
self.progressBarInit(processEvents=None)
239254

240255
def __set_update_loop(self, loop):
241256
if self.__update_loop is not None:
242-
self.__update_loop.close()
257+
if self.__state in (OWtSNE.Finished, OWtSNE.Waiting):
258+
self.__update_loop.close()
243259
self.__update_loop = None
244260
self.progressBarFinished(processEvents=None)
245261

@@ -255,8 +271,10 @@ def __set_update_loop(self, loop):
255271
else:
256272
self.setBlocking(False)
257273
self.setStatusMessage("")
258-
self.runbutton.setText("Start")
259-
self.__state = OWtSNE.Finished
274+
if self.__state in (OWtSNE.Finished, OWtSNE.Waiting):
275+
self.runbutton.setText("Start")
276+
if self.__state == OWtSNE.Paused:
277+
self.runbutton.setText("Resume")
260278
self.__timer.stop()
261279

262280
def __next_step(self):
@@ -273,13 +291,16 @@ def __next_step(self):
273291
projection, progress = next(self.__update_loop)
274292
assert self.__update_loop is loop
275293
except StopIteration:
294+
self.__state = OWtSNE.Finished
276295
self.__set_update_loop(None)
277296
self.unconditional_commit()
278297
except MemoryError:
279298
self.Error.out_of_memory()
299+
self.__state = OWtSNE.Finished
280300
self.__set_update_loop(None)
281301
except Exception as exc:
282302
self.Error.optimization_error(str(exc))
303+
self.__state = OWtSNE.Finished
283304
self.__set_update_loop(None)
284305
else:
285306
self.progressBarSet(100.0 * progress, processEvents=None)
@@ -321,8 +342,8 @@ def send_preprocessor(self):
321342

322343
def clear(self):
323344
super().clear()
324-
self.__set_update_loop(None)
325345
self.__state = OWtSNE.Waiting
346+
self.__set_update_loop(None)
326347
self.pca_data = None
327348
self.projection = None
328349

Orange/widgets/unsupervised/tests/test_owtsne.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import unittest
22
import numpy as np
33

4+
from AnyQt.QtTest import QSignalSpy
5+
46
from Orange.data import DiscreteVariable, ContinuousVariable, Domain, Table
57
from Orange.preprocess import Preprocess
68
from Orange.projection.manifold import TSNE
@@ -39,7 +41,8 @@ def optimize(*_, **__):
3941
owtsne.TSNEModel.transform = transform
4042
owtsne.TSNEModel.optimize = optimize
4143

42-
self.widget = self.create_widget(OWtSNE)
44+
self.widget = self.create_widget(OWtSNE,
45+
stored_settings={"multiscale": False})
4346

4447
self.class_var = DiscreteVariable('Stage name', values=['STG1', 'STG2'])
4548
self.attributes = [ContinuousVariable('GeneName' + str(i)) for i in range(5)]
@@ -110,7 +113,11 @@ def test_attr_models(self):
110113
self.assertIn(var, controls.attr_shape.model())
111114

112115
def test_output_preprocessor(self):
116+
self.reset_tsne()
113117
self.send_signal(self.widget.Inputs.data, self.data)
118+
if self.widget.isBlocking():
119+
spy = QSignalSpy(self.widget.blockingStateChanged)
120+
self.assertTrue(spy.wait(20000))
114121
pp = self.get_output(self.widget.Outputs.preprocessor)
115122
self.assertIsInstance(pp, Preprocess)
116123
transformed = pp(self.data)
@@ -123,15 +130,15 @@ def test_output_preprocessor(self):
123130
[m.name for m in output.domain.metas[:2]])
124131

125132
def test_multiscale_changed(self):
126-
self.assertTrue(self.widget.controls.multiscale.isChecked())
127-
self.assertFalse(self.widget.perplexity_spin.isEnabled())
128-
self.widget.controls.multiscale.setChecked(False)
133+
self.assertFalse(self.widget.controls.multiscale.isChecked())
129134
self.assertTrue(self.widget.perplexity_spin.isEnabled())
135+
self.widget.controls.multiscale.setChecked(True)
136+
self.assertFalse(self.widget.perplexity_spin.isEnabled())
130137

131138
settings = self.widget.settingsHandler.pack_data(self.widget)
132139
w = self.create_widget(OWtSNE, stored_settings=settings)
133-
self.assertFalse(w.controls.multiscale.isChecked())
134-
self.assertTrue(w.perplexity_spin.isEnabled())
140+
self.assertTrue(w.controls.multiscale.isChecked())
141+
self.assertFalse(w.perplexity_spin.isEnabled())
135142

136143

137144
if __name__ == '__main__':

0 commit comments

Comments
 (0)