Skip to content

Commit 3fe6160

Browse files
authored
Merge pull request #3604 from pavlin-policar/tsne-threaded
[ENH] OWTSNE: Offload computation to separate thread
2 parents f099ced + c712620 commit 3fe6160

File tree

4 files changed

+812
-306
lines changed

4 files changed

+812
-306
lines changed

Orange/projection/manifold.py

Lines changed: 66 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -265,7 +265,10 @@ def optimize(self, n_iter, inplace=False, propagate_exception=False, **kwargs):
265265
new_embedding = self.embedding_.optimize(**kwargs)
266266
table = Table(self.embedding.domain, new_embedding.view(np.ndarray),
267267
self.embedding.Y, self.embedding.metas)
268-
return TSNEModel(new_embedding, table, self.pre_domain)
268+
269+
new_model = TSNEModel(new_embedding, table, self.pre_domain)
270+
new_model.name = self.name
271+
return new_model
269272

270273

271274
class TSNE(Projector):
@@ -400,7 +403,7 @@ def __init__(self, n_components=2, perplexity=30, learning_rate=200,
400403
self.callbacks_every_iters = callbacks_every_iters
401404
self.random_state = random_state
402405

403-
def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
406+
def compute_affinities(self, X):
404407
# Sparse data are not supported
405408
if sp.issparse(X):
406409
raise TypeError(
@@ -415,41 +418,75 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
415418
if not isinstance(self.perplexity, Iterable):
416419
raise ValueError(
417420
"Perplexity should be an instance of `Iterable`, `%s` "
418-
"given." % type(self.perplexity).__name__)
421+
"given." % type(self.perplexity).__name__
422+
)
419423
affinities = openTSNE.affinity.Multiscale(
420-
X, perplexities=self.perplexity, metric=self.metric,
421-
method=self.neighbors, random_state=self.random_state, n_jobs=self.n_jobs)
424+
X,
425+
perplexities=self.perplexity,
426+
metric=self.metric,
427+
method=self.neighbors,
428+
random_state=self.random_state,
429+
n_jobs=self.n_jobs,
430+
)
422431
else:
423432
if isinstance(self.perplexity, Iterable):
424433
raise ValueError(
425434
"Perplexity should be an instance of `float`, `%s` "
426-
"given." % type(self.perplexity).__name__)
435+
"given." % type(self.perplexity).__name__
436+
)
427437
affinities = openTSNE.affinity.PerplexityBasedNN(
428-
X, perplexity=self.perplexity, metric=self.metric,
429-
method=self.neighbors, random_state=self.random_state, n_jobs=self.n_jobs)
438+
X,
439+
perplexity=self.perplexity,
440+
metric=self.metric,
441+
method=self.neighbors,
442+
random_state=self.random_state,
443+
n_jobs=self.n_jobs,
444+
)
430445

431-
# Create an initial embedding
446+
return affinities
447+
448+
def compute_initialization(self, X):
449+
# Compute the initial positions of individual points
432450
if isinstance(self.initialization, np.ndarray):
433451
initialization = self.initialization
434452
elif self.initialization == "pca":
435453
initialization = openTSNE.initialization.pca(
436-
X, self.n_components, random_state=self.random_state)
454+
X, self.n_components, random_state=self.random_state
455+
)
437456
elif self.initialization == "random":
438457
initialization = openTSNE.initialization.random(
439-
X, self.n_components, random_state=self.random_state)
458+
X, self.n_components, random_state=self.random_state
459+
)
440460
else:
441461
raise ValueError(
442462
"Invalid initialization `%s`. Please use either `pca` or "
443-
"`random` or provide a numpy array." % self.initialization)
463+
"`random` or provide a numpy array." % self.initialization
464+
)
444465

445-
embedding = openTSNE.TSNEEmbedding(
446-
initialization, affinities, learning_rate=self.learning_rate,
447-
theta=self.theta, min_num_intervals=self.min_num_intervals,
448-
ints_in_interval=self.ints_in_interval, n_jobs=self.n_jobs,
466+
return initialization
467+
468+
def prepare_embedding(self, affinities, initialization):
469+
"""Prepare an embedding object with appropriate parameters, given some
470+
affinities and initialization."""
471+
return openTSNE.TSNEEmbedding(
472+
initialization,
473+
affinities,
474+
learning_rate=self.learning_rate,
475+
theta=self.theta,
476+
min_num_intervals=self.min_num_intervals,
477+
ints_in_interval=self.ints_in_interval,
478+
n_jobs=self.n_jobs,
449479
negative_gradient_method=self.negative_gradient_method,
450-
callbacks=self.callbacks, callbacks_every_iters=self.callbacks_every_iters,
480+
callbacks=self.callbacks,
481+
callbacks_every_iters=self.callbacks_every_iters,
451482
)
452483

484+
def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
485+
# Compute affinities and initial positions and prepare the embedding object
486+
affinities = self.compute_affinities(X)
487+
initialization = self.compute_initialization(X)
488+
embedding = self.prepare_embedding(affinities, initialization)
489+
453490
# Run standard t-SNE optimization
454491
embedding.optimize(
455492
n_iter=self.early_exaggeration_iter, exaggeration=self.early_exaggeration,
@@ -462,13 +499,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
462499

463500
return embedding
464501

465-
def __call__(self, data: Table) -> TSNEModel:
466-
# Preprocess the data - convert discrete to continuous
467-
data = self.preprocess(data)
468-
469-
# Run tSNE optimization
470-
embedding = self.fit(data.X, data.Y)
471-
502+
def convert_embedding_to_model(self, data, embedding):
472503
# The results should be accessible in an Orange table, which doesn't
473504
# need the full embedding attributes and is cast into a regular array
474505
n = self.n_components
@@ -484,6 +515,17 @@ def __call__(self, data: Table) -> TSNEModel:
484515

485516
return model
486517

518+
def __call__(self, data: Table) -> TSNEModel:
519+
# Preprocess the data - convert discrete to continuous
520+
data = self.preprocess(data)
521+
522+
# Run tSNE optimization
523+
embedding = self.fit(data.X, data.Y)
524+
525+
# Convert the t-SNE embedding object to a TSNEModel and prepare the
526+
# embedding table with t-SNE meta variables
527+
return self.convert_embedding_to_model(data, embedding)
528+
487529
@staticmethod
488530
def default_initialization(data, n_components=2, random_state=None):
489531
return openTSNE.initialization.pca(

Orange/widgets/tests/base.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -890,9 +890,11 @@ def test_setup_graph(self, timeout=DEFAULT_TIMEOUT):
890890
"""Plot should exist after data has been sent in order to be
891891
properly set/updated"""
892892
self.send_signal(self.widget.Inputs.data, self.data)
893+
893894
if self.widget.isBlocking():
894895
spy = QSignalSpy(self.widget.blockingStateChanged)
895896
self.assertTrue(spy.wait(timeout))
897+
896898
self.assertIsNotNone(self.widget.graph.scatterplot_item)
897899

898900
def test_default_attrs(self, timeout=DEFAULT_TIMEOUT):
@@ -967,6 +969,9 @@ def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
967969
self.widget.setup_plot = Mock()
968970
self.widget.commit = Mock()
969971
self.send_signal(self.widget.Inputs.data, table)
972+
self.widget.setup_plot.assert_called_once()
973+
self.widget.commit.assert_called_once()
974+
970975
if self.widget.isBlocking():
971976
spy = QSignalSpy(self.widget.blockingStateChanged)
972977
self.assertTrue(spy.wait(timeout))
@@ -975,9 +980,6 @@ def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
975980

976981
self.widget.commit.reset_mock()
977982
self.send_signal(self.widget.Inputs.data_subset, table[::10])
978-
if self.widget.isBlocking():
979-
spy = QSignalSpy(self.widget.blockingStateChanged)
980-
self.assertTrue(spy.wait(timeout))
981983
self.widget.setup_plot.assert_called_once()
982984
self.widget.commit.assert_called_once()
983985

@@ -1034,16 +1036,20 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
10341036
self.widget.graph.update_coordinates = Mock()
10351037
self.widget.graph.update_point_props = Mock()
10361038
self.send_signal(self.widget.Inputs.data, self.data)
1037-
self.widget.graph.update_coordinates.assert_called_once()
1038-
self.widget.graph.update_point_props.assert_called_once()
1039-
10401039
if self.widget.isBlocking():
10411040
spy = QSignalSpy(self.widget.blockingStateChanged)
10421041
self.assertTrue(spy.wait(timeout))
10431042

1043+
self.widget.graph.update_coordinates.assert_called()
1044+
self.widget.graph.update_point_props.assert_called()
1045+
10441046
self.widget.graph.update_coordinates.reset_mock()
10451047
self.widget.graph.update_point_props.reset_mock()
10461048
self.send_signal(self.widget.Inputs.data, self.data)
1049+
if self.widget.isBlocking():
1050+
spy = QSignalSpy(self.widget.blockingStateChanged)
1051+
self.assertTrue(spy.wait(timeout))
1052+
10471053
self.widget.graph.update_coordinates.assert_not_called()
10481054
self.widget.graph.update_point_props.assert_called_once()
10491055

@@ -1052,13 +1058,16 @@ def test_saved_selection(self, timeout=DEFAULT_TIMEOUT):
10521058
if self.widget.isBlocking():
10531059
spy = QSignalSpy(self.widget.blockingStateChanged)
10541060
self.assertTrue(spy.wait(timeout))
1061+
10551062
self.widget.graph.select_by_indices(list(range(0, len(self.data), 10)))
10561063
settings = self.widget.settingsHandler.pack_data(self.widget)
10571064
w = self.create_widget(self.widget.__class__, stored_settings=settings)
1065+
10581066
self.send_signal(self.widget.Inputs.data, self.data, widget=w)
10591067
if w.isBlocking():
10601068
spy = QSignalSpy(w.blockingStateChanged)
10611069
self.assertTrue(spy.wait(timeout))
1070+
10621071
self.assertEqual(np.sum(w.graph.selection), 15)
10631072
np.testing.assert_equal(self.widget.graph.selection, w.graph.selection)
10641073

0 commit comments

Comments
 (0)