Skip to content

Commit 5f115ee

Browse files
Fix and update amg state precomputation in 2d annotator
1 parent b71999b commit 5f115ee

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,11 @@ def _segment_widget(v: Viewer) -> None:
4040
v.layers["current_object"].refresh()
4141

4242

43-
def _get_amg(is_tiled, with_background=True, box_extension=0.05):
43+
def _get_amg(is_tiled):
4444
if is_tiled:
45-
amg = instance_segmentation.TiledAutomaticMaskGenerator(
46-
PREDICTOR, with_background=with_background, box_extension=box_extension,
47-
)
45+
amg = instance_segmentation.TiledAutomaticMaskGenerator(PREDICTOR)
4846
else:
49-
amg = instance_segmentation.EmbeddingMaskGenerator(PREDICTOR, box_extension=box_extension)
47+
amg = instance_segmentation.AutomaticMaskGenerator(PREDICTOR)
5048
return amg
5149

5250

@@ -64,7 +62,8 @@ def _autosegment_widget(
6462
v: Viewer,
6563
pred_iou_thresh: float = 0.88,
6664
stability_score_thresh: float = 0.95,
67-
min_object_size: int = 25,
65+
min_object_size: int = 100,
66+
with_background: bool = True,
6867
) -> None:
6968
global AMG
7069
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
@@ -203,7 +202,7 @@ def _precompute_amg_state(raw, save_path):
203202
AMG.set_state(amg_state)
204203
return
205204

206-
print("Precomputing the state for instance segmentation")
205+
print("Precomputing the state for instance segmentation.")
207206
AMG.initialize(raw, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
208207
with open(save_path_amg, "wb") as f:
209208
pickle.dump(AMG.get_state(), f)
@@ -220,6 +219,7 @@ def annotator_2d(
220219
return_viewer: bool = False,
221220
v: Optional[Viewer] = None,
222221
predictor: Optional[SamPredictor] = None,
222+
precompute_amg_state: bool = False,
223223
) -> Optional[Viewer]:
224224
"""The 2d annotation tool.
225225
@@ -242,6 +242,9 @@ def annotator_2d(
242242
This enables using a pre-initialized viewer, for example in `sam_annotator.image_series_annotator`.
243243
predictor: The Segment Anything model. Passing this enables using fully custom models.
244244
If you pass `predictor` then `model_type` will be ignored.
245+
precompute_amg_state: Whether to precompute the state for automatic mask generation.
246+
This will take more time when precomputing embeddings, but will then make
247+
automatic mask generation much faster.
245248
246249
Returns:
247250
The napari viewer, only returned if `return_viewer=True`.
@@ -254,11 +257,12 @@ def annotator_2d(
254257
PREDICTOR = util.get_sam_model(model_type=model_type)
255258
else:
256259
PREDICTOR = predictor
260+
257261
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
258262
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
259263
wrong_file_callback=show_wrong_file_warning
260264
)
261-
if embedding_path is not None:
265+
if precompute_amg_state and (embedding_path is not None):
262266
_precompute_amg_state(raw, embedding_path)
263267

264268
# we set the pre-computed image embeddings if we don't use tiling
@@ -287,6 +291,7 @@ def annotator_2d(
287291
def main():
288292
"""@private"""
289293
parser = vutil._initialize_parser(description="Run interactive segmentation for an image.")
294+
parser.add_argument("--precompute_amg_state", action="store_true")
290295
args = parser.parse_args()
291296
raw = util.load_image_data(args.input, key=args.key)
292297

@@ -302,4 +307,5 @@ def main():
302307
raw, embedding_path=args.embedding_path,
303308
show_embeddings=args.show_embeddings, segmentation_result=segmentation_result,
304309
model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
310+
precompute_amg_state=args.precompute_amg_state,
305311
)

0 commit comments

Comments
 (0)