Skip to content

Commit b4cd6c7

Browse files
2 parents c5f04d9 + 53117be commit b4cd6c7

File tree

3 files changed

+50
-32
lines changed

3 files changed

+50
-32
lines changed

micro_sam/sam_annotator/_widgets.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from superqt import QCollapsible
2222
from magicgui import magic_factory
2323
from magicgui.widgets import ComboBox, Container, create_widget
24-
from napari.qt.threading import thread_worker
24+
# We have disabled the thread workers for now because they result in a
25+
# massive slowdown in napari >= 0.5.
26+
# See also https://forum.image.sc/t/napari-thread-worker-leads-to-massive-slowdown/103786
27+
# from napari.qt.threading import thread_worker
2528
from napari.utils import progress
2629

2730
from ._state import AnnotatorState
@@ -1088,7 +1091,7 @@ def __call__(self, skip_validate=False):
10881091
# Set up progress bar and signals for using it within a threadworker.
10891092
pbar, pbar_signals = _create_pbar_for_threadworker()
10901093

1091-
@thread_worker()
1094+
# @thread_worker()
10921095
def compute_image_embedding():
10931096

10941097
def pbar_init(total, description):
@@ -1103,10 +1106,12 @@ def pbar_init(total, description):
11031106
)
11041107
pbar_signals.pbar_stop.emit()
11051108

1106-
worker = compute_image_embedding()
1107-
worker.returned.connect(self._update_model)
1108-
worker.start()
1109-
return worker
1109+
compute_image_embedding()
1110+
self._update_model()
1111+
# worker = compute_image_embedding()
1112+
# worker.returned.connect(self._update_model)
1113+
# worker.start()
1114+
# return worker
11101115

11111116

11121117
#
@@ -1195,7 +1200,7 @@ def _run_tracking(self):
11951200
state = AnnotatorState()
11961201
pbar, pbar_signals = _create_pbar_for_threadworker()
11971202

1198-
@thread_worker
1203+
# @thread_worker
11991204
def tracking_impl():
12001205
shape = state.image_shape
12011206

@@ -1237,15 +1242,17 @@ def update_segmentation(ret_val):
12371242
self._viewer.layers["current_object"].data[seg == 1] = state.current_track_id
12381243
self._viewer.layers["current_object"].refresh()
12391244

1240-
worker = tracking_impl()
1241-
worker.returned.connect(update_segmentation)
1242-
worker.start()
1243-
return worker
1245+
ret_val = tracking_impl()
1246+
update_segmentation(ret_val)
1247+
# worker = tracking_impl()
1248+
# worker.returned.connect(update_segmentation)
1249+
# worker.start()
1250+
# return worker
12441251

12451252
def _run_volumetric_segmentation(self):
12461253
pbar, pbar_signals = _create_pbar_for_threadworker()
12471254

1248-
@thread_worker
1255+
# @thread_worker
12491256
def volumetric_segmentation_impl():
12501257
state = AnnotatorState()
12511258
shape = state.image_shape
@@ -1277,10 +1284,13 @@ def update_segmentation(seg):
12771284
self._viewer.layers["current_object"].data = seg
12781285
self._viewer.layers["current_object"].refresh()
12791286

1280-
worker = volumetric_segmentation_impl()
1281-
worker.returned.connect(update_segmentation)
1282-
worker.start()
1283-
return worker
1287+
seg = volumetric_segmentation_impl()
1288+
self._viewer.layers["current_object"].data = seg
1289+
self._viewer.layers["current_object"].refresh()
1290+
# worker = volumetric_segmentation_impl()
1291+
# worker.returned.connect(update_segmentation)
1292+
# worker.start()
1293+
# return worker
12841294

12851295
def __call__(self):
12861296
if _validate_embeddings(self._viewer):
@@ -1522,7 +1532,7 @@ def _empty_segmentation_warning(self):
15221532
def _run_segmentation_2d(self, kwargs, i=None):
15231533
pbar, pbar_signals = _create_pbar_for_threadworker()
15241534

1525-
@thread_worker
1535+
# @thread_worker
15261536
def seg_impl():
15271537
def pbar_init(total, description):
15281538
pbar_signals.pbar_total.emit(total)
@@ -1548,10 +1558,12 @@ def update_segmentation(seg):
15481558
self._viewer.layers["auto_segmentation"].data[i] = seg
15491559
self._viewer.layers["auto_segmentation"].refresh()
15501560

1551-
worker = seg_impl()
1552-
worker.returned.connect(update_segmentation)
1553-
worker.start()
1554-
return worker
1561+
seg = seg_impl()
1562+
update_segmentation(seg)
1563+
# worker = seg_impl()
1564+
# worker.returned.connect(update_segmentation)
1565+
# worker.start()
1566+
# return worker
15551567

15561568
# We refuse to run 3D segmentation with the AMG unless we have a GPU or all embeddings
15571569
# are precomputed. Otherwise this would take too long.
@@ -1578,7 +1590,7 @@ def _run_segmentation_3d(self, kwargs):
15781590

15791591
pbar, pbar_signals = _create_pbar_for_threadworker()
15801592

1581-
@thread_worker
1593+
# @thread_worker
15821594
def seg_impl():
15831595
segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
15841596
offset = 0
@@ -1617,10 +1629,12 @@ def update_segmentation(segmentation):
16171629
self._viewer.layers["auto_segmentation"].data = segmentation
16181630
self._viewer.layers["auto_segmentation"].refresh()
16191631

1620-
worker = seg_impl()
1621-
worker.returned.connect(update_segmentation)
1622-
worker.start()
1623-
return worker
1632+
seg = seg_impl()
1633+
update_segmentation(seg)
1634+
# worker = seg_impl()
1635+
# worker.returned.connect(update_segmentation)
1636+
# worker.start()
1637+
# return worker
16241638

16251639
def __call__(self):
16261640
if _validate_embeddings(self._viewer):

micro_sam/sam_annotator/training_ui.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33

44
from qtpy import QtWidgets
5-
from napari.qt.threading import thread_worker
5+
# from napari.qt.threading import thread_worker
66

77
import torch
88
from torch.utils.data import random_split
@@ -238,7 +238,7 @@ def __call__(self, skip_validate=False):
238238
else:
239239
checkpoint_path = self.checkpoint
240240

241-
@thread_worker()
241+
# @thread_worker()
242242
def run_training():
243243
train_loader, val_loader = self._get_loaders()
244244
train_sam_for_configuration(
@@ -296,7 +296,9 @@ def run_training():
296296
pbar_signals.pbar_stop.emit()
297297
return export_checkpoint
298298

299-
worker = run_training()
300-
worker.returned.connect(lambda path: print(f"Training has finished. The trained model is saved at {path}."))
301-
worker.start()
302-
return worker
299+
path = run_training()
300+
print(f"Training has finished. The trained model is saved at {path}.")
301+
# worker = run_training()
302+
# worker.returned.connect(lambda path: print(f"Training has finished. The trained model is saved at {path}."))
303+
# worker.start()
304+
# return worker

micro_sam/sam_annotator/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
119119
viewer.layers["point_prompts"].refresh()
120120
if "prompts" in viewer.layers:
121121
# Select all prompts and then remove them.
122+
# This is how it worked before napari 0.5.
123+
# viewer.layers["prompts"].data = []
122124
viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
123125
viewer.layers["prompts"].remove_selected()
124126
viewer.layers["prompts"].refresh()

0 commit comments

Comments
 (0)