Skip to content

Commit db0a98e

Browse files
OwTSNE: Offload work to separate thread
1 parent bd906f7 commit db0a98e

File tree

5 files changed

+701
-285
lines changed

5 files changed

+701
-285
lines changed

Orange/projection/manifold.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
496496

497497
return embedding
498498

499-
def __call__(self, data: Table) -> TSNEModel:
500-
# Preprocess the data - convert discrete to continuous
501-
data = self.preprocess(data)
502-
503-
# Run tSNE optimization
504-
embedding = self.fit(data.X, data.Y)
505-
499+
def convert_embedding_to_model(self, data, embedding):
506500
# The results should be accessible in an Orange table, which doesn't
507501
# need the full embedding attributes and is cast into a regular array
508502
n = self.n_components
@@ -518,6 +512,17 @@ def __call__(self, data: Table) -> TSNEModel:
518512

519513
return model
520514

515+
def __call__(self, data: Table) -> TSNEModel:
516+
# Preprocess the data - convert discrete to continuous
517+
data = self.preprocess(data)
518+
519+
# Run tSNE optimization
520+
embedding = self.fit(data.X, data.Y)
521+
522+
# Convert the t-SNE embedding object to a TSNEModel and prepare the
523+
# embedding table with t-SNE meta variables
524+
return self.convert_embedding_to_model(data, embedding)
525+
521526
@staticmethod
522527
def default_initialization(data, n_components=2, random_state=None):
523528
return openTSNE.initialization.pca(

Orange/widgets/tests/base.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -857,10 +857,15 @@ def _compare_selected_annotated_domains(self, selected, annotated):
857857
annotated_vars = annotated.domain.variables
858858
self.assertLessEqual(set(selected_vars), set(annotated_vars))
859859

860-
def test_setup_graph(self):
860+
def test_setup_graph(self, timeout=DEFAULT_TIMEOUT):
861861
"""Plot should exist after data has been sent in order to be
862862
properly set/updated"""
863863
self.send_signal(self.widget.Inputs.data, self.data)
864+
865+
if self.widget.isBlocking():
866+
spy = QSignalSpy(self.widget.blockingStateChanged)
867+
self.assertTrue(spy.wait(timeout))
868+
864869
self.assertIsNotNone(self.widget.graph.scatterplot_item)
865870

866871
def test_default_attrs(self, timeout=DEFAULT_TIMEOUT):
@@ -934,16 +939,21 @@ def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
934939
table = Table("heart_disease")
935940
self.widget.setup_plot = Mock()
936941
self.widget.commit = Mock()
942+
937943
self.send_signal(self.widget.Inputs.data, table)
944+
if self.widget.isBlocking():
945+
spy = QSignalSpy(self.widget.blockingStateChanged)
946+
self.assertTrue(spy.wait(timeout))
947+
938948
self.widget.setup_plot.assert_called_once()
939949
self.widget.commit.assert_called_once()
940950

951+
self.widget.commit.reset_mock()
952+
self.send_signal(self.widget.Inputs.data_subset, table[::10])
941953
if self.widget.isBlocking():
942954
spy = QSignalSpy(self.widget.blockingStateChanged)
943955
self.assertTrue(spy.wait(timeout))
944956

945-
self.widget.commit.reset_mock()
946-
self.send_signal(self.widget.Inputs.data_subset, table[::10])
947957
self.widget.setup_plot.assert_called_once()
948958
self.widget.commit.assert_called_once()
949959

@@ -985,25 +995,38 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
985995
self.widget.graph.update_coordinates = Mock()
986996
self.widget.graph.update_point_props = Mock()
987997
self.send_signal(self.widget.Inputs.data, self.data)
988-
self.widget.graph.update_coordinates.assert_called_once()
989-
self.widget.graph.update_point_props.assert_called_once()
990-
991998
if self.widget.isBlocking():
992999
spy = QSignalSpy(self.widget.blockingStateChanged)
9931000
self.assertTrue(spy.wait(timeout))
9941001

1002+
self.widget.graph.update_coordinates.assert_called()
1003+
self.widget.graph.update_point_props.assert_called()
1004+
9951005
self.widget.graph.update_coordinates.reset_mock()
9961006
self.widget.graph.update_point_props.reset_mock()
9971007
self.send_signal(self.widget.Inputs.data, self.data)
1008+
if self.widget.isBlocking():
1009+
spy = QSignalSpy(self.widget.blockingStateChanged)
1010+
self.assertTrue(spy.wait(timeout))
1011+
9981012
self.widget.graph.update_coordinates.assert_not_called()
9991013
self.widget.graph.update_point_props.assert_called_once()
10001014

1001-
def test_saved_selection(self):
1015+
def test_saved_selection(self, timeout=DEFAULT_TIMEOUT):
10021016
self.send_signal(self.widget.Inputs.data, self.data)
1017+
if self.widget.isBlocking():
1018+
spy = QSignalSpy(self.widget.blockingStateChanged)
1019+
self.assertTrue(spy.wait(timeout))
1020+
10031021
self.widget.graph.select_by_indices(list(range(0, len(self.data), 10)))
10041022
settings = self.widget.settingsHandler.pack_data(self.widget)
10051023
w = self.create_widget(self.widget.__class__, stored_settings=settings)
1024+
10061025
self.send_signal(self.widget.Inputs.data, self.data, widget=w)
1026+
if w.isBlocking():
1027+
spy = QSignalSpy(w.blockingStateChanged)
1028+
self.assertTrue(spy.wait(timeout))
1029+
10071030
self.assertEqual(np.sum(w.graph.selection), 15)
10081031
np.testing.assert_equal(self.widget.graph.selection, w.graph.selection)
10091032

0 commit comments

Comments
 (0)