Commit 214cdde

Updates for the documentation

1 parent: 0301642

File tree

5 files changed: +60 −49 lines

doc/annotation_tools.md
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# Data Annotation Tools

doc/installaiton.md
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# Installation

doc/start_page.md
Lines changed: 7 additions & 3 deletions

@@ -2,6 +2,10 @@
 
 Segment Anything for Microscopy implements automatic and interactive annotation for microscopy data. It is built on top of [Segment Anything](https://segment-anything.com/) by Meta AI and specializes it for microscopy and other bio-imaging data.
 Its core components are:
-- The `micro_sam` tool: implements interactive data annotation using [napari](https://napari.org/stable/).
-- The `micro_sam` python library: implements functionality for applying Segment Anything to multi-dimensional data, and to fine-tune it on custom datasets.
-- The `micro_sam` models: new Segment Anything models that were fine-tuned on publicly available microscopy data.
+- The `micro_sam` annotator tools: interactive data annotation with [napari](https://napari.org/stable/) applications.
+- The `micro_sam` library: apply Segment Anything to multi-dimensional data or fine-tune it on your data.
+- The `micro_sam` models: Segment Anything models fine-tuned on publicly available microscopy data.
+
+## Quickstart
+
+## Citation
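
The revised component list pairs naturally with a usage example. Below is a minimal sketch of the library component; `util.precompute_image_embeddings` and `initialize` appear in the `instance_segmentation.py` changes further down, while `util.get_sam_model`, the `AutomaticMaskGenerator` class name, and the input file are assumptions for illustration, not shown in this commit.

```python
# Hypothetical sketch of the `micro_sam` library component described above.
# `get_sam_model` and `AutomaticMaskGenerator` are assumed names; only
# `precompute_image_embeddings` and `initialize` are visible in this commit.
import imageio.v3 as imageio

from micro_sam import instance_segmentation, util

image = imageio.imread("cells.tif")  # placeholder 2d microscopy image

predictor = util.get_sam_model()  # load a (possibly fine-tuned) SAM model
amg = instance_segmentation.AutomaticMaskGenerator(predictor)

# the image embeddings are computed once and reused for all point prompts
embeddings = util.precompute_image_embeddings(predictor, image)
amg.initialize(image, image_embeddings=embeddings)
masks = amg.generate()
```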

micro_sam/__init__.py
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+"""
+.. include:: ../doc/start_page.md
+.. include:: ../doc/annotation_tools.md
+.. include:: ../doc/installaiton.md
+"""

micro_sam/instance_segmentation.py
Lines changed: 46 additions & 46 deletions

@@ -135,14 +135,14 @@ def _postprocess_batch(
 
         # calculate stability score
         data["stability_score"] = amg_utils.calculate_stability_score(
-            data["masks"], self.predictor.model.mask_threshold, stability_score_offset
+            data["masks"], self._predictor.model.mask_threshold, stability_score_offset
         )
         if stability_score_thresh > 0.0:
             keep_mask = data["stability_score"] >= stability_score_thresh
             data.filter(keep_mask)
 
         # threshold masks and calculate boxes
-        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
         data["boxes"] = amg_utils.batched_mask_to_box(data["masks"])
 
         # filter boxes that touch crop boundaries
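
For context on the logic around the first renamed attribute: the stability score measures how much a predicted mask changes when the logit cutoff is shifted by ±`stability_score_offset`; masks whose area is insensitive to the cutoff are kept. A rough sketch of the computation, equivalent in spirit to `segment_anything`'s `calculate_stability_score` (exact implementation details may differ):

```python
import torch

def stability_score(mask_logits: torch.Tensor, mask_threshold: float, offset: float) -> torch.Tensor:
    # IoU between the mask binarized at a high and at a low cutoff.
    # The high-threshold mask is contained in the low-threshold one,
    # so intersection / union reduces to a ratio of pixel counts.
    intersection = (mask_logits > (mask_threshold + offset)).flatten(1).sum(-1).float()
    union = (mask_logits > (mask_threshold - offset)).flatten(1).sum(-1).float()
    return intersection / union
```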
@@ -327,19 +327,19 @@ def __init__(
         else:
             raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
 
-        self.predictor = predictor
-        self.points_per_side = points_per_side
-        self.points_per_batch = points_per_batch
-        self.crop_n_layers = crop_n_layers
-        self.crop_overlap_ratio = crop_overlap_ratio
-        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self._predictor = predictor
+        self._points_per_side = points_per_side
+        self._points_per_batch = points_per_batch
+        self._crop_n_layers = crop_n_layers
+        self._crop_overlap_ratio = crop_overlap_ratio
+        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 
     def _process_batch(self, points, im_size):
         # run model on this batch
-        transformed_points = self.predictor.transform.apply_coords(points, im_size)
-        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+        transformed_points = self._predictor.transform.apply_coords(points, im_size)
+        in_points = torch.as_tensor(transformed_points, device=self._predictor.device)
         in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
-        masks, iou_preds, _ = self.predictor.predict_torch(
+        masks, iou_preds, _ = self._predictor.predict_torch(
             in_points[:, None, :],
             in_labels[:, None],
             multimask_output=True,

@@ -363,18 +363,18 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
         cropped_im_size = cropped_im.shape[:2]
 
         if not precomputed_embeddings:
-            self.predictor.set_image(cropped_im)
+            self._predictor.set_image(cropped_im)
 
         # get the points for this crop
         points_scale = np.array(cropped_im_size)[None, ::-1]
         points_for_image = self.point_grids[crop_layer_idx] * points_scale
 
         # generate masks for this crop in batches
         data = amg_utils.MaskData()
-        n_batches = len(points_for_image) // self.points_per_batch +\
-            int(len(points_for_image) % self.points_per_batch != 0)
+        n_batches = len(points_for_image) // self._points_per_batch +\
+            int(len(points_for_image) % self._points_per_batch != 0)
         for (points,) in tqdm(
-            amg_utils.batch_iterator(self.points_per_batch, points_for_image),
+            amg_utils.batch_iterator(self._points_per_batch, points_for_image),
             disable=not verbose, total=n_batches,
             desc="Predict masks for point grid prompts",
         ):

@@ -383,7 +383,7 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
             del batch_data
 
         if not precomputed_embeddings:
-            self.predictor.reset_image()
+            self._predictor.reset_image()
 
         return data

@@ -407,16 +407,16 @@ def initialize(
         """
         original_size = image.shape[:2]
         crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
-            original_size, self.crop_n_layers, self.crop_overlap_ratio
+            original_size, self._crop_n_layers, self._crop_overlap_ratio
         )
 
         # we can set fixed image embeddings if we only have a single crop box
         # (which is the default setting)
         # otherwise we have to recompute the embeddings for each crop and can't precompute
         if len(crop_boxes) == 1:
             if image_embeddings is None:
-                image_embeddings = util.precompute_image_embeddings(self.predictor, image)
-            util.set_precomputed(self.predictor, image_embeddings, i=i)
+                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
+            util.set_precomputed(self._predictor, image_embeddings, i=i)
             precomputed_embeddings = True
         else:
             precomputed_embeddings = False
@@ -537,33 +537,33 @@ def __init__(
     ):
         super().__init__()
 
-        self.predictor = predictor
-        self.offsets = self.default_offsets if offsets is None else offsets
-        self.min_initial_size = min_initial_size
-        self.distance_type = distance_type
-        self.bias = bias
-        self.use_box = use_box
-        self.use_mask = use_mask
-        self.use_points = use_points
-        self.box_extension = box_extension
+        self._predictor = predictor
+        self._offsets = self.default_offsets if offsets is None else offsets
+        self._min_initial_size = min_initial_size
+        self._distance_type = distance_type
+        self._bias = bias
+        self._use_box = use_box
+        self._use_mask = use_mask
+        self._use_points = use_points
+        self._box_extension = box_extension
 
         # additional state that is set 'initialize'
         self._initial_segmentation = None
 
     def _compute_initial_segmentation(self):
 
-        embeddings = self.predictor.get_image_embedding().squeeze().cpu().numpy()
+        embeddings = self._predictor.get_image_embedding().squeeze().cpu().numpy()
         assert embeddings.shape == (256, 64, 64), f"{embeddings.shape}"
 
         initial_segmentation = embed.segment_embeddings_mws(
-            embeddings, distance_type=self.distance_type, offsets=self.offsets, bias=self.bias,
+            embeddings, distance_type=self._distance_type, offsets=self._offsets, bias=self._bias,
         ).astype("uint32")
         assert initial_segmentation.shape == (64, 64), f"{initial_segmentation.shape}"
 
         # filter out small initial objects
-        if self.min_initial_size > 0:
+        if self._min_initial_size > 0:
             seg_ids, sizes = np.unique(initial_segmentation, return_counts=True)
-            initial_segmentation[np.isin(initial_segmentation, seg_ids[sizes < self.min_initial_size])] = 0
+            initial_segmentation[np.isin(initial_segmentation, seg_ids[sizes < self._min_initial_size])] = 0
 
         # resize to 256 x 256, which is the mask input expected by SAM
         initial_segmentation = resize(

@@ -582,10 +582,10 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
         for seg_id in tqdm(seg_ids, disable=not verbose, desc="Compute masks from initial segmentation"):
             mask = initial_segmentation == seg_id
             masks, iou_preds, _ = segment_from_mask(
-                self.predictor, mask, original_size=original_size,
+                self._predictor, mask, original_size=original_size,
                 multimask_output=True, return_logits=True, return_all=True,
-                use_box=self.use_box, use_mask=self.use_mask, use_points=self.use_points,
-                box_extension=self.box_extension,
+                use_box=self._use_box, use_mask=self._use_mask, use_points=self._use_points,
+                box_extension=self._box_extension,
             )
             data = amg_utils.MaskData(
                 masks=torch.from_numpy(masks),

@@ -618,8 +618,8 @@ def initialize(
         original_size = image.shape[:2]
 
         if image_embeddings is None:
-            image_embeddings = util.precompute_image_embeddings(self.predictor, image,)
-        util.set_precomputed(self.predictor, image_embeddings, i=i)
+            image_embeddings = util.precompute_image_embeddings(self._predictor, image,)
+        util.set_precomputed(self._predictor, image_embeddings, i=i)
 
         # compute the initial segmentation via embedding based MWS and then refine the masks
         # with the segment anything model
@@ -737,8 +737,8 @@ def __init__(
         **kwargs
     ):
         super().__init__(predictor=predictor, **kwargs)
-        self.n_threads = n_threads
-        self.with_background = with_background
+        self._n_threads = n_threads
+        self._with_background = with_background
 
         # additional state for 'initialize'
         self._tile_shape = None

@@ -758,10 +758,10 @@ def segment_tile(tile_id):
                 "input_size": tile_features.attrs["input_size"],
                 "original_size": tile_features.attrs["original_size"]
             }
-            util.set_precomputed(self.predictor, tile_image_embeddings, i)
+            util.set_precomputed(self._predictor, tile_image_embeddings, i)
             return self._compute_initial_segmentation()
 
-        with futures.ThreadPoolExecutor(self.n_threads) as tp:
+        with futures.ThreadPoolExecutor(self._n_threads) as tp:
             initial_segmentations = list(tqdm(
                 tp.map(segment_tile, range(n_tiles)), disable=not verbose, total=n_tiles,
                 desc="Tile-based initial segmentation"

@@ -781,7 +781,7 @@ def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n
                 "input_size": tile_features.attrs["input_size"],
                 "original_size": this_tile_shape
             }
-            util.set_precomputed(self.predictor, tile_image_embeddings, i)
+            util.set_precomputed(self._predictor, tile_image_embeddings, i)
             tile_data = self._compute_mask_data(initial_segmentations[tile_id], this_tile_shape, verbose=False)
             mask_data.append(tile_data)
 

@@ -821,7 +821,7 @@ def initialize(
                 "Embeddings with tiling can only be computed if a save path is given."
             )
             image_embeddings = util.precompute_image_embeddings(
-                self.predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
+                self._predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
             )
         elif image_embeddings is None and not have_tiling_params:
             raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)")

@@ -903,12 +903,12 @@ def segment_tile(_, tile_id):
             mask_data = self._postprocess_masks(
                 mask_data, 0, box_nms_thresh, box_nms_thresh, output_mode="binary_mask"
             )
-            mask_data = mask_data_to_segmentation(mask_data, this_tile_shape, with_background=self.with_background)
+            mask_data = mask_data_to_segmentation(mask_data, this_tile_shape, with_background=self._with_background)
             return mask_data
 
         input_ = _FakeInput(self.original_size)
         segmentation = stitch_segmentation(
-            input_, segment_tile, self._tile_shape, self._halo, with_background=self.with_background, verbose=verbose
+            input_, segment_tile, self._tile_shape, self._halo, with_background=self._with_background, verbose=verbose
         )
 
         if min_mask_region_area > 0:

@@ -940,7 +940,7 @@ def segment_tile(_, tile_id):
         initial_segmentation = stitch_segmentation(
             input_, segment_tile,
             self._tile_shape, self._halo,
-            with_background=self.with_background, verbose=False
+            with_background=self._with_background, verbose=False
         )
 
         self._stitched_initial_segmentation = initial_segmentation
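
The rest of this diff is one mechanical refactor: the predictor and the hyperparameters move from public attributes (`self.predictor`, `self.points_per_batch`, ...) to underscore-prefixed ones, marking them as internal state of the mask generator classes. Where read access should stay public, the usual companion to this rename is a read-only property; a hypothetical illustration of the pattern, not part of this commit:

```python
class MaskGenerator:
    """Hypothetical excerpt illustrating the rename pattern in this diff."""

    def __init__(self, predictor, points_per_side=32):
        # internal state, signalled by the leading underscore
        self._predictor = predictor
        self._points_per_side = points_per_side

    @property
    def predictor(self):
        # read-only public access; without a setter, callers cannot
        # rebind the predictor after construction
        return self._predictor
```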
