Skip to content

Commit 07b48c8

Browse files
lanzagarpavlin-policar
authored andcommitted
Merge pull request biolab#3592 from pavlin-policar/tsne-faster
[FIX] t-SNE speed-ups
2 parents dd6093d + 914588c commit 07b48c8

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

Orange/widgets/unsupervised/owtsne.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818

1919

2020
class TSNERunner:
21-
def __init__(self, tsne: TSNEModel, step_size=50):
21+
def __init__(self, tsne: TSNEModel, step_size=50, exaggeration=1):
2222
self.embedding = tsne
2323
self.iterations_done = 0
2424
self.step_size = step_size
25+
self.exaggeration = exaggeration
2526

2627
# Larger data sets need a larger number of iterations
2728
if self.n_samples > 100_000:
@@ -43,7 +44,7 @@ def run_optimization(self):
4344
while not current_iter >= total_iterations:
4445
# Switch to normal regime if early exaggeration phase is over
4546
if current_iter >= self.early_exagg_iter:
46-
exaggeration, momentum = 1, 0.8
47+
exaggeration, momentum = self.exaggeration, 0.8
4748

4849
# Resume optimization for some number of steps
4950
self.embedding.optimize(
@@ -73,7 +74,7 @@ class OWtSNE(OWDataProjectionWidget):
7374
settings_version = 3
7475
max_iter = Setting(300)
7576
perplexity = Setting(30)
76-
multiscale = Setting(True)
77+
multiscale = Setting(False)
7778
exaggeration = Setting(1)
7879
pca_components = Setting(20)
7980
normalize = Setting(True)
@@ -110,6 +111,13 @@ def __init__(self):
110111
self.__in_next_step = False
111112
self.__draw_similar_pairs = False
112113

114+
def reset_needs_to_draw():
115+
self.needs_to_draw = True
116+
117+
self.needs_to_draw = True
118+
self.__timer_draw = QTimer(self, interval=2000,
119+
timeout=reset_needs_to_draw)
120+
113121
def _add_controls(self):
114122
self._add_controls_start_box()
115123
super()._add_controls()
@@ -258,6 +266,8 @@ def pca_preprocessing(self):
258266
def __start(self):
259267
self.pca_preprocessing()
260268

269+
self.needs_to_draw = True
270+
261271
# We call PCA through fastTSNE because it involves scaling. Instead of
262272
# worrying about this ourselves, we'll let the library worry for us.
263273
initialization = TSNE.default_initialization(
@@ -281,10 +291,13 @@ def __start(self):
281291
n_components=2, perplexity=perplexity, multiscale=self.multiscale,
282292
early_exaggeration_iter=0, n_iter=0, initialization=initialization,
283293
exaggeration=self.exaggeration, neighbors=neighbor_method,
284-
negative_gradient_method=gradient_method, random_state=0
294+
negative_gradient_method=gradient_method, random_state=0,
295+
theta=0.8,
285296
)(self.pca_data)
286297

287-
self.tsne_runner = TSNERunner(self.projection, step_size=50)
298+
self.tsne_runner = TSNERunner(
299+
self.projection, step_size=20, exaggeration=self.exaggeration
300+
)
288301
self.tsne_iterator = self.tsne_runner.run_optimization()
289302
self.__set_update_loop(self.tsne_iterator)
290303
self.progressBarInit(processEvents=None)
@@ -305,6 +318,7 @@ def __set_update_loop(self, loop):
305318
self.runbutton.setText("Stop")
306319
self.__state = OWtSNE.Running
307320
self.__timer.start()
321+
self.__timer_draw.start()
308322
else:
309323
self.setBlocking(False)
310324
self.setStatusMessage("")
@@ -313,6 +327,7 @@ def __set_update_loop(self, loop):
313327
if self.__state == OWtSNE.Paused:
314328
self.runbutton.setText("Resume")
315329
self.__timer.stop()
330+
self.__timer_draw.stop()
316331

317332
def __next_step(self):
318333
if self.__update_loop is None:
@@ -342,8 +357,10 @@ def __next_step(self):
342357
else:
343358
self.progressBarSet(100.0 * progress, processEvents=None)
344359
self.projection = projection
345-
self.graph.update_coordinates()
346-
self.graph.update_density()
360+
if progress == 1 or self.needs_to_draw:
361+
self.graph.update_coordinates()
362+
self.graph.update_density()
363+
self.needs_to_draw = False
347364
# schedule next update
348365
self.__timer.start()
349366

@@ -360,15 +377,16 @@ def commit(self):
360377
def _get_projection_data(self):
361378
if self.data is None:
362379
return None
363-
if self.projection is None:
364-
variables = self._get_projection_variables()
365-
else:
366-
variables = self.projection.domain.attributes
367380
data = self.data.transform(
368381
Domain(self.data.domain.attributes,
369382
self.data.domain.class_vars,
370-
self.data.domain.metas + variables))
383+
self.data.domain.metas + self._get_projection_variables()))
371384
data.metas[:, -2:] = self.get_embedding()
385+
if self.projection is not None:
386+
data.domain = Domain(
387+
self.data.domain.attributes,
388+
self.data.domain.class_vars,
389+
self.data.domain.metas + self.projection.domain.attributes)
372390
return data
373391

374392
def send_preprocessor(self):

Orange/widgets/unsupervised/tests/test_owtsne.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,31 @@ def test_normalize_data(self):
170170
self.assertFalse(self.widget.controls.normalize.isEnabled())
171171
normalize.assert_not_called()
172172

173+
@patch("Orange.projection.manifold.TSNEModel.optimize")
174+
def test_exaggeration_is_passed_through_properly(self, optimize):
175+
def _check_exaggeration(call, exaggeration):
176+
# Check the last call to `optimize`, so we catch one during the
177+
# regular regime
178+
name, args, kwargs = call.mock_calls[-1]
179+
self.assertIn("exaggeration", kwargs)
180+
self.assertEqual(kwargs["exaggeration"], exaggeration)
181+
182+
# Set value to 1
183+
self.widget.controls.exaggeration.setValue(1)
184+
self.send_signal(self.widget.Inputs.data, self.data)
185+
self.commit_and_wait()
186+
_check_exaggeration(optimize, 1)
187+
188+
# Reset and clear state
189+
optimize.reset_mock()
190+
self.send_signal(self.widget.Inputs.data, None)
191+
192+
# Change to 3
193+
self.widget.controls.exaggeration.setValue(3)
194+
self.send_signal(self.widget.Inputs.data, self.data)
195+
self.commit_and_wait()
196+
_check_exaggeration(optimize, 3)
197+
173198

174199
if __name__ == '__main__':
175200
unittest.main()

0 commit comments

Comments
 (0)