Skip to content

Commit e628b78

Browse files
OWTSNE: Offload computation to separate thread
1 parent 10269b5 commit e628b78

File tree

4 files changed

+644
-295
lines changed

4 files changed

+644
-295
lines changed

Orange/projection/manifold.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,10 @@ def optimize(self, n_iter, inplace=False, propagate_exception=False, **kwargs):
265265
new_embedding = self.embedding_.optimize(**kwargs)
266266
table = Table(self.embedding.domain, new_embedding.view(np.ndarray),
267267
self.embedding.Y, self.embedding.metas)
268-
return TSNEModel(new_embedding, table, self.pre_domain)
268+
269+
new_model = TSNEModel(new_embedding, table, self.pre_domain)
270+
new_model.name = self.name
271+
return new_model
269272

270273

271274
class TSNE(Projector):
@@ -400,7 +403,7 @@ def __init__(self, n_components=2, perplexity=30, learning_rate=200,
400403
self.callbacks_every_iters = callbacks_every_iters
401404
self.random_state = random_state
402405

403-
def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
406+
def compute_affinities(self, X):
404407
# Sparse data are not supported
405408
if sp.issparse(X):
406409
raise TypeError(
@@ -415,41 +418,75 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
415418
if not isinstance(self.perplexity, Iterable):
416419
raise ValueError(
417420
"Perplexity should be an instance of `Iterable`, `%s` "
418-
"given." % type(self.perplexity).__name__)
421+
"given." % type(self.perplexity).__name__
422+
)
419423
affinities = openTSNE.affinity.Multiscale(
420-
X, perplexities=self.perplexity, metric=self.metric,
421-
method=self.neighbors, random_state=self.random_state, n_jobs=self.n_jobs)
424+
X,
425+
perplexities=self.perplexity,
426+
metric=self.metric,
427+
method=self.neighbors,
428+
random_state=self.random_state,
429+
n_jobs=self.n_jobs,
430+
)
422431
else:
423432
if isinstance(self.perplexity, Iterable):
424433
raise ValueError(
425434
"Perplexity should be an instance of `float`, `%s` "
426-
"given." % type(self.perplexity).__name__)
435+
"given." % type(self.perplexity).__name__
436+
)
427437
affinities = openTSNE.affinity.PerplexityBasedNN(
428-
X, perplexity=self.perplexity, metric=self.metric,
429-
method=self.neighbors, random_state=self.random_state, n_jobs=self.n_jobs)
438+
X,
439+
perplexity=self.perplexity,
440+
metric=self.metric,
441+
method=self.neighbors,
442+
random_state=self.random_state,
443+
n_jobs=self.n_jobs,
444+
)
430445

431-
# Create an initial embedding
446+
return affinities
447+
448+
def compute_initialization(self, X):
449+
# Compute the initial positions of individual points
432450
if isinstance(self.initialization, np.ndarray):
433451
initialization = self.initialization
434452
elif self.initialization == "pca":
435453
initialization = openTSNE.initialization.pca(
436-
X, self.n_components, random_state=self.random_state)
454+
X, self.n_components, random_state=self.random_state
455+
)
437456
elif self.initialization == "random":
438457
initialization = openTSNE.initialization.random(
439-
X, self.n_components, random_state=self.random_state)
458+
X, self.n_components, random_state=self.random_state
459+
)
440460
else:
441461
raise ValueError(
442462
"Invalid initialization `%s`. Please use either `pca` or "
443-
"`random` or provide a numpy array." % self.initialization)
463+
"`random` or provide a numpy array." % self.initialization
464+
)
444465

445-
embedding = openTSNE.TSNEEmbedding(
446-
initialization, affinities, learning_rate=self.learning_rate,
447-
theta=self.theta, min_num_intervals=self.min_num_intervals,
448-
ints_in_interval=self.ints_in_interval, n_jobs=self.n_jobs,
466+
return initialization
467+
468+
def prepare_embedding(self, affinities, initialization):
469+
"""Prepare an embedding object with appropriate parameters, given some
470+
affinities and initialization."""
471+
return openTSNE.TSNEEmbedding(
472+
initialization,
473+
affinities,
474+
learning_rate=self.learning_rate,
475+
theta=self.theta,
476+
min_num_intervals=self.min_num_intervals,
477+
ints_in_interval=self.ints_in_interval,
478+
n_jobs=self.n_jobs,
449479
negative_gradient_method=self.negative_gradient_method,
450-
callbacks=self.callbacks, callbacks_every_iters=self.callbacks_every_iters,
480+
callbacks=self.callbacks,
481+
callbacks_every_iters=self.callbacks_every_iters,
451482
)
452483

484+
def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
485+
# Compute affinities and initial positions and prepare the embedding object
486+
affinities = self.compute_affinities(X)
487+
initialization = self.compute_initialization(X)
488+
embedding = self.prepare_embedding(affinities, initialization)
489+
453490
# Run standard t-SNE optimization
454491
embedding.optimize(
455492
n_iter=self.early_exaggeration_iter, exaggeration=self.early_exaggeration,
@@ -462,13 +499,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
462499

463500
return embedding
464501

465-
def __call__(self, data: Table) -> TSNEModel:
466-
# Preprocess the data - convert discrete to continuous
467-
data = self.preprocess(data)
468-
469-
# Run tSNE optimization
470-
embedding = self.fit(data.X, data.Y)
471-
502+
def convert_embedding_to_model(self, data, embedding):
472503
# The results should be accessible in an Orange table, which doesn't
473504
# need the full embedding attributes and is cast into a regular array
474505
n = self.n_components
@@ -484,6 +515,17 @@ def __call__(self, data: Table) -> TSNEModel:
484515

485516
return model
486517

518+
def __call__(self, data: Table) -> TSNEModel:
519+
# Preprocess the data - convert discrete to continuous
520+
data = self.preprocess(data)
521+
522+
# Run tSNE optimization
523+
embedding = self.fit(data.X, data.Y)
524+
525+
# Convert the t-SNE embedding object to a TSNEModel and prepare the
526+
# embedding table with t-SNE meta variables
527+
return self.convert_embedding_to_model(data, embedding)
528+
487529
@staticmethod
488530
def default_initialization(data, n_components=2, random_state=None):
489531
return openTSNE.initialization.pca(

Orange/widgets/tests/base.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -861,9 +861,11 @@ def test_setup_graph(self, timeout=DEFAULT_TIMEOUT):
861861
"""Plot should exist after data has been sent in order to be
862862
properly set/updated"""
863863
self.send_signal(self.widget.Inputs.data, self.data)
864+
864865
if self.widget.isBlocking():
865866
spy = QSignalSpy(self.widget.blockingStateChanged)
866867
self.assertTrue(spy.wait(timeout))
868+
867869
self.assertIsNotNone(self.widget.graph.scatterplot_item)
868870

869871
def test_default_attrs(self, timeout=DEFAULT_TIMEOUT):
@@ -937,16 +939,21 @@ def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
937939
table = Table("heart_disease")
938940
self.widget.setup_plot = Mock()
939941
self.widget.commit = Mock()
942+
940943
self.send_signal(self.widget.Inputs.data, table)
944+
if self.widget.isBlocking():
945+
spy = QSignalSpy(self.widget.blockingStateChanged)
946+
self.assertTrue(spy.wait(timeout))
947+
941948
self.widget.setup_plot.assert_called_once()
942949
self.widget.commit.assert_called_once()
943950

951+
self.widget.commit.reset_mock()
952+
self.send_signal(self.widget.Inputs.data_subset, table[::10])
944953
if self.widget.isBlocking():
945954
spy = QSignalSpy(self.widget.blockingStateChanged)
946955
self.assertTrue(spy.wait(timeout))
947956

948-
self.widget.commit.reset_mock()
949-
self.send_signal(self.widget.Inputs.data_subset, table[::10])
950957
self.widget.setup_plot.assert_called_once()
951958
self.widget.commit.assert_called_once()
952959

@@ -1003,16 +1010,20 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
10031010
self.widget.graph.update_coordinates = Mock()
10041011
self.widget.graph.update_point_props = Mock()
10051012
self.send_signal(self.widget.Inputs.data, self.data)
1006-
self.widget.graph.update_coordinates.assert_called_once()
1007-
self.widget.graph.update_point_props.assert_called_once()
1008-
10091013
if self.widget.isBlocking():
10101014
spy = QSignalSpy(self.widget.blockingStateChanged)
10111015
self.assertTrue(spy.wait(timeout))
10121016

1017+
self.widget.graph.update_coordinates.assert_called()
1018+
self.widget.graph.update_point_props.assert_called()
1019+
10131020
self.widget.graph.update_coordinates.reset_mock()
10141021
self.widget.graph.update_point_props.reset_mock()
10151022
self.send_signal(self.widget.Inputs.data, self.data)
1023+
if self.widget.isBlocking():
1024+
spy = QSignalSpy(self.widget.blockingStateChanged)
1025+
self.assertTrue(spy.wait(timeout))
1026+
10161027
self.widget.graph.update_coordinates.assert_not_called()
10171028
self.widget.graph.update_point_props.assert_called_once()
10181029

@@ -1021,13 +1032,16 @@ def test_saved_selection(self, timeout=DEFAULT_TIMEOUT):
10211032
if self.widget.isBlocking():
10221033
spy = QSignalSpy(self.widget.blockingStateChanged)
10231034
self.assertTrue(spy.wait(timeout))
1035+
10241036
self.widget.graph.select_by_indices(list(range(0, len(self.data), 10)))
10251037
settings = self.widget.settingsHandler.pack_data(self.widget)
10261038
w = self.create_widget(self.widget.__class__, stored_settings=settings)
1039+
10271040
self.send_signal(self.widget.Inputs.data, self.data, widget=w)
10281041
if w.isBlocking():
10291042
spy = QSignalSpy(w.blockingStateChanged)
10301043
self.assertTrue(spy.wait(timeout))
1044+
10311045
self.assertEqual(np.sum(w.graph.selection), 15)
10321046
np.testing.assert_equal(self.widget.graph.selection, w.graph.selection)
10331047

0 commit comments

Comments
 (0)