Skip to content

Commit f7dbb03

Browse files
OwTSNE: Don't clear computed values at invalidation, set flags instead
1 parent 9517bf5 commit f7dbb03

File tree

2 files changed

+99
-9
lines changed

2 files changed

+99
-9
lines changed

Orange/widgets/unsupervised/owtsne.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,24 @@ def update_coordinates(self):
230230
self.view_box.setAspectLocked(True, 1)
231231

232232

233+
class invalidated:
234+
pca_projection = affinities = tsne_embedding = False
235+
236+
def __set__(self, instance, value):
237+
# `self._invalidate = True` should invalidate everything
238+
self.pca_projection = self.affinities = self.tsne_embedding = value
239+
240+
def __bool__(self):
241+
# If any of the values are invalidated, this should return true
242+
return self.pca_projection or self.affinities or self.tsne_embedding
243+
244+
def __str__(self):
245+
return "%s(%s)" % (self.__class__.__name__, ", ".join(
246+
"=".join([k, str(getattr(self, k))])
247+
for k in ["pca_projection", "affinities", "tsne_embedding"]
248+
))
249+
250+
233251
class OWtSNE(OWDataProjectionWidget, ConcurrentWidgetMixin):
234252
name = "t-SNE"
235253
description = "Two-dimensional data projection with t-SNE."
@@ -250,6 +268,11 @@ class OWtSNE(OWDataProjectionWidget, ConcurrentWidgetMixin):
250268

251269
left_side_scrolling = True
252270

271+
# Use `invalidated` descriptor so we don't break the usage of
272+
# `_invalidated` in `OWDataProjectionWidget`, but still allow finer control
273+
# over which parts of the embedding to invalidate
274+
_invalidated = invalidated()
275+
253276
class Information(OWDataProjectionWidget.Information):
254277
modified = Msg("The parameter settings have been changed. Press "
255278
"\"Start\" to rerun with the new settings.")
@@ -323,21 +346,19 @@ def _multiscale_changed(self):
323346
self._invalidate_affinities()
324347

325348
def _invalidate_pca_projection(self):
326-
self.pca_projection = None
327-
self.initialization = None
349+
self._invalidated.pca_projection = True
328350
self._invalidate_affinities()
329351

330352
def _invalidate_affinities(self):
331-
self.affinities = None
353+
self._invalidated.affinities = True
332354
self._invalidate_tsne_embedding()
333355

334356
def _invalidate_tsne_embedding(self):
335-
self.iterations_done = 0
336-
self.tsne_embedding = None
337-
self._invalidate_output()
357+
self._invalidated.tsne_embedding = True
358+
self._stop_running_task()
338359
self._set_modified(True)
339360

340-
def _invalidate_output(self):
361+
def _stop_running_task(self):
341362
self.cancel()
342363
self.run_button.setText("Start")
343364

@@ -403,8 +424,15 @@ def _toggle_run(self):
403424
else:
404425
self.run()
405426

406-
def set_data(self, data: Table):
407-
super().set_data(data)
427+
def handleNewSignals(self):
428+
# We don't bother with the granular invalidation flags because
429+
# `super().handleNewSignals` will just set all of them to False or will
430+
# do nothing. However, it's important we remember its state because we
431+
# won't call `run` if needed. `run` also relies on the state of
432+
# `_invalidated` to properly set the intermediate values to None
433+
prev_invalidated = bool(self._invalidated)
434+
super().handleNewSignals()
435+
self._invalidated = prev_invalidated
408436

409437
if self._invalidated:
410438
self.run()
@@ -445,7 +473,17 @@ def enable_controls(self):
445473
self.controls.perplexity.setDisabled(self.multiscale)
446474

447475
def run(self):
476+
# Reset invalidated values as indicated by the flags
477+
if self._invalidated.pca_projection:
478+
self.pca_projection = None
479+
if self._invalidated.affinities:
480+
self.affinities = None
481+
if self._invalidated.tsne_embedding:
482+
self.iterations_done = 0
483+
self.tsne_embedding = None
484+
448485
self._set_modified(False)
486+
self._invalidated = False
449487

450488
# When the data is invalid, it is set to `None` and an error is set,
451489
# therefore it would be erroneous to clear the error here

Orange/widgets/unsupervised/tests/test_owtsne.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,58 @@ def test_modified_info_message_behaviour(self):
288288
"The information message was not cleared on no data"
289289
)
290290

291+
def test_invalidation_flow(self):
292+
w = self.widget
293+
# Setup widget: send data to input with global structure "off", then
294+
# set global structure "on" (after the embedding is computed)
295+
w.controls.multiscale.setChecked(False)
296+
self.send_signal(w.Inputs.data, self.data)
297+
self.wait_until_stop_blocking()
298+
self.assertFalse(self.widget.Information.modified.is_shown())
299+
# All the embedding components should computed
300+
self.assertIsNotNone(w.pca_projection)
301+
self.assertIsNotNone(w.affinities)
302+
self.assertIsNotNone(w.tsne_embedding)
303+
# All the invalidation flags should be set to false
304+
self.assertFalse(w._invalidated.pca_projection)
305+
self.assertFalse(w._invalidated.affinities)
306+
self.assertFalse(w._invalidated.tsne_embedding)
307+
308+
# Trigger invalidation
309+
w.controls.multiscale.setChecked(True)
310+
self.assertTrue(self.widget.Information.modified.is_shown())
311+
# Setting `multiscale` to true should set the invalidate flags for
312+
# the affinities and embedding, but not the pca_projection
313+
self.assertFalse(w._invalidated.pca_projection)
314+
self.assertTrue(w._invalidated.affinities)
315+
self.assertTrue(w._invalidated.tsne_embedding)
316+
317+
# The flags should now be set, but the embedding should still be
318+
# available when selecting a subset of data and such
319+
self.assertIsNotNone(w.pca_projection)
320+
self.assertIsNotNone(w.affinities)
321+
self.assertIsNotNone(w.tsne_embedding)
322+
323+
# We should still be able to send a data subset to the input and have
324+
# the points be highlighted
325+
self.send_signal(w.Inputs.data_subset, self.data[:10])
326+
self.wait_until_stop_blocking()
327+
subset = [brush.color().name() == "#46befa" for brush in
328+
w.graph.scatterplot_item.data["brush"][:10]]
329+
other = [brush.color().name() == "#000000" for brush in
330+
w.graph.scatterplot_item.data["brush"][10:]]
331+
self.assertTrue(all(subset))
332+
self.assertTrue(all(other))
333+
334+
# Clear the data subset
335+
self.send_signal(w.Inputs.data_subset, None)
336+
337+
# Run the optimization
338+
self.widget.run_button.clicked.emit()
339+
self.wait_until_stop_blocking()
340+
# All of the inavalidation flags should have been cleared
341+
self.assertFalse(w._invalidated)
342+
291343

292344
class TestTSNERunner(unittest.TestCase):
293345
@classmethod

0 commit comments

Comments
 (0)