Skip to content

Commit a1ce544

Browse files
Update instance segmentation to be more memory efficient
1 parent e758122 commit a1ce544

File tree

1 file changed

+68
-63
lines changed

1 file changed

+68
-63
lines changed

micro_sam/instance_segmentation.py

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def _postprocess_batch(
135135
original_size,
136136
pred_iou_thresh,
137137
stability_score_thresh,
138-
stability_score_offset,
139138
box_nms_thresh,
140139
):
141140
orig_h, orig_w = original_size
@@ -145,28 +144,16 @@ def _postprocess_batch(
145144
keep_mask = data["iou_preds"] > pred_iou_thresh
146145
data.filter(keep_mask)
147146

148-
# calculate stability score
149-
data["stability_score"] = amg_utils.calculate_stability_score(
150-
data["masks"], self._predictor.model.mask_threshold, stability_score_offset
151-
)
147+
# filter by stability score
152148
if stability_score_thresh > 0.0:
153149
keep_mask = data["stability_score"] >= stability_score_thresh
154150
data.filter(keep_mask)
155151

156-
# threshold masks and calculate boxes
157-
data["masks"] = data["masks"] > self._predictor.model.mask_threshold
158-
data["boxes"] = amg_utils.batched_mask_to_box(data["masks"])
159-
160152
# filter boxes that touch crop boundaries
161153
keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
162154
if not torch.all(keep_mask):
163155
data.filter(keep_mask)
164156

165-
# compress to RLE
166-
data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
167-
data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
168-
del data["masks"]
169-
170157
# remove duplicates within this crop.
171158
keep_by_nms = batched_nms(
172159
data["boxes"].float(),
@@ -267,6 +254,32 @@ def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, cr
267254

268255
return curr_anns
269256

257+
def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
258+
orig_h, orig_w = original_size
259+
260+
# serialize predictions and store in MaskData
261+
data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
262+
if points is not None:
263+
data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0))
264+
265+
del masks
266+
267+
# calculate the stability scores
268+
data["stability_score"] = amg_utils.calculate_stability_score(
269+
data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
270+
)
271+
272+
# threshold masks and calculate boxes
273+
data["masks"] = data["masks"] > self._predictor.model.mask_threshold
274+
data["boxes"] = amg_utils.batched_mask_to_box(data["masks"])
275+
276+
# compress to RLE
277+
data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
278+
data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
279+
del data["masks"]
280+
281+
return data
282+
270283
def get_state(self) -> Dict[str, Any]:
271284
"""Get the initialized state of the mask generator.
272285
@@ -315,6 +328,7 @@ class AutomaticMaskGenerator(AMGBase):
315328
crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
316329
point_grids: A list of explicit grids of points used for sampling masks.
317330
Normalized to [0, 1] with respect to the image coordinate system.
331+
stability_score_offset: The amount to shift the cutoff when calculating the stability score.
318332
"""
319333
def __init__(
320334
self,
@@ -325,6 +339,7 @@ def __init__(
325339
crop_overlap_ratio: float = 512 / 1500,
326340
crop_n_points_downscale_factor: int = 1,
327341
point_grids: Optional[List[np.ndarray]] = None,
342+
stability_score_offset: float = 1.0,
328343
):
329344
super().__init__()
330345

@@ -345,8 +360,9 @@ def __init__(
345360
self._crop_n_layers = crop_n_layers
346361
self._crop_overlap_ratio = crop_overlap_ratio
347362
self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
363+
self._stability_score_offset = stability_score_offset
348364

349-
def _process_batch(self, points, im_size):
365+
def _process_batch(self, points, im_size, crop_box, original_size):
350366
# run model on this batch
351367
transformed_points = self._predictor.transform.apply_coords(points, im_size)
352368
in_points = torch.as_tensor(transformed_points, device=self._predictor.device)
@@ -357,24 +373,14 @@ def _process_batch(self, points, im_size):
357373
multimask_output=True,
358374
return_logits=True,
359375
)
360-
361-
# serialize predictions and store in MaskData
362-
data = amg_utils.MaskData(
363-
masks=masks.flatten(0, 1),
364-
iou_preds=iou_preds.flatten(0, 1),
365-
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
366-
)
376+
data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
367377
del masks
368-
369378
return data
370379

371380
def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
372381
# crop the image and calculate embeddings
373-
if crop_box is None:
374-
cropped_im = image
375-
else:
376-
x0, y0, x1, y1 = crop_box
377-
cropped_im = image[y0:y1, x0:x1, :]
382+
x0, y0, x1, y1 = crop_box
383+
cropped_im = image[y0:y1, x0:x1, :]
378384
cropped_im_size = cropped_im.shape[:2]
379385

380386
if not precomputed_embeddings:
@@ -393,7 +399,7 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
393399
disable=not verbose, total=n_batches,
394400
desc="Predict masks for point grid prompts",
395401
):
396-
batch_data = self._process_batch(points, cropped_im_size)
402+
batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
397403
data.cat(batch_data)
398404
del batch_data
399405

@@ -421,6 +427,8 @@ def initialize(
421427
verbose: Whether to print computation progress.
422428
"""
423429
original_size = image.shape[:2]
430+
self._original_size = original_size
431+
424432
crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
425433
original_size, self._crop_n_layers, self._crop_overlap_ratio
426434
)
@@ -449,14 +457,12 @@ def initialize(
449457
self._is_initialized = True
450458
self._crop_list = crop_list
451459
self._crop_boxes = crop_boxes
452-
self._original_size = original_size
453460

454461
@torch.no_grad()
455462
def generate(
456463
self,
457464
pred_iou_thresh: float = 0.88,
458465
stability_score_thresh: float = 0.95,
459-
stability_score_offset: float = 1.0,
460466
box_nms_thresh: float = 0.7,
461467
crop_nms_thresh: float = 0.7,
462468
min_mask_region_area: int = 0,
@@ -468,7 +474,6 @@ def generate(
468474
pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
469475
stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
470476
under changes to the cutoff used to binarize the model prediction.
471-
stability_score_offset: The amount to shift the cutoff when calculating the stability score.
472477
box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
473478
crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
474479
min_mask_region_area: Minimal size for the predicted masks.
@@ -487,7 +492,6 @@ def generate(
487492
crop_box=crop_box, original_size=self.original_size,
488493
pred_iou_thresh=pred_iou_thresh,
489494
stability_score_thresh=stability_score_thresh,
490-
stability_score_offset=stability_score_offset,
491495
box_nms_thresh=box_nms_thresh
492496
)
493497
data.cat(crop_data)
@@ -535,6 +539,7 @@ class EmbeddingMaskGenerator(AMGBase):
535539
use_mask: Whether to use the initial segments as prompts.
536540
use_points: Whether to use points derived from the initial segments as prompts.
537541
box_extension: Factor for extending the bounding box prompts, given in the relative box size.
542+
stability_score_offset: The amount to shift the cutoff when calculating the stability score.
538543
"""
539544
default_offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9]]
540545

@@ -549,6 +554,7 @@ def __init__(
549554
use_mask: bool = True,
550555
use_points: bool = False,
551556
box_extension: float = 0.05,
557+
stability_score_offset: float = 1.0,
552558
):
553559
super().__init__()
554560

@@ -561,6 +567,7 @@ def __init__(
561567
self._use_mask = use_mask
562568
self._use_points = use_points
563569
self._box_extension = box_extension
570+
self._stability_score_offset = stability_score_offset
564571

565572
# additional state that is set in 'initialize'
566573
self._initial_segmentation = None
@@ -587,7 +594,7 @@ def _compute_initial_segmentation(self):
587594

588595
return initial_segmentation
589596

590-
def _compute_mask_data(self, initial_segmentation, original_size, verbose):
597+
def _compute_mask_data(self, initial_segmentation, crop_box, original_size, verbose):
591598
seg_ids = np.unique(initial_segmentation)
592599
if seg_ids[0] == 0:
593600
seg_ids = seg_ids[1:]
@@ -602,11 +609,9 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
602609
use_box=self._use_box, use_mask=self._use_mask, use_points=self._use_points,
603610
box_extension=self._box_extension,
604611
)
605-
data = amg_utils.MaskData(
606-
masks=torch.from_numpy(masks),
607-
iou_preds=torch.from_numpy(iou_preds),
608-
seg_id=torch.from_numpy(np.full(len(masks), seg_id, dtype="int64")),
609-
)
612+
# bring masks and iou_preds to a format compatible with _to_mask_data
613+
masks, iou_preds = torch.from_numpy(masks[None]), torch.from_numpy(iou_preds[None])
614+
data = self._to_mask_data(masks, iou_preds, crop_box, original_size)
610615
del masks
611616
mask_data.cat(data)
612617

@@ -631,6 +636,11 @@ def initialize(
631636
verbose: Whether to print computation progress.
632637
"""
633638
original_size = image.shape[:2]
639+
self._original_size = original_size
640+
641+
# the crop box is always the full image
642+
crop_box = [0, 0, original_size[1], original_size[0]]
643+
self._crop_boxes = [crop_box]
634644

635645
if image_embeddings is None:
636646
image_embeddings = util.precompute_image_embeddings(self._predictor, image,)
@@ -639,26 +649,20 @@ def initialize(
639649
# compute the initial segmentation via embedding based MWS and then refine the masks
640650
# with the segment anything model
641651
initial_segmentation = self._compute_initial_segmentation()
642-
mask_data = self._compute_mask_data(initial_segmentation, original_size, verbose)
652+
mask_data = self._compute_mask_data(initial_segmentation, crop_box, original_size, verbose)
643653
# to be compatible with the file format of the super class we have to wrap the mask data in a list
644654
crop_list = [mask_data]
645655

646656
# set the initialized data
647657
self._is_initialized = True
648658
self._initial_segmentation = initial_segmentation
649659
self._crop_list = crop_list
650-
# the crop box is always the full image
651-
self._crop_boxes = [
652-
[0, 0, original_size[1], original_size[0]]
653-
]
654-
self._original_size = original_size
655660

656661
@torch.no_grad()
657662
def generate(
658663
self,
659664
pred_iou_thresh: float = 0.88,
660665
stability_score_thresh: float = 0.95,
661-
stability_score_offset: float = 1.0,
662666
box_nms_thresh: float = 0.7,
663667
min_mask_region_area: int = 0,
664668
output_mode: str = "binary_mask",
@@ -669,7 +673,6 @@ def generate(
669673
pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
670674
stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
671675
under changes to the cutoff used to binarize the model prediction.
672-
stability_score_offset: The amount to shift the cutoff when calculating the stability score.
673676
box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
674677
min_mask_region_area: Minimal size for the predicted masks.
675678
output_mode: The form masks are returned in.
@@ -685,7 +688,6 @@ def generate(
685688
original_size=self.original_size,
686689
pred_iou_thresh=pred_iou_thresh,
687690
stability_score_thresh=stability_score_thresh,
688-
stability_score_offset=stability_score_offset,
689691
box_nms_thresh=box_nms_thresh
690692
)
691693

@@ -777,6 +779,7 @@ class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
777779
Higher numbers may be faster but use more GPU memory.
778780
point_grids: A list of explicit grids of points used for sampling masks.
779781
Normalized to [0, 1] with respect to the image coordinate system.
782+
stability_score_offset: The amount to shift the cutoff when calculating the stability score.
780783
"""
781784

782785
# We only expose the arguments that make sense for the tiled mask generator.
@@ -788,12 +791,14 @@ def __init__(
788791
points_per_side: Optional[int] = 32,
789792
points_per_batch: int = 64,
790793
point_grids: Optional[List[np.ndarray]] = None,
794+
stability_score_offset: float = 1.0,
791795
) -> None:
792796
super().__init__(
793797
predictor=predictor,
794798
points_per_side=points_per_side,
795799
points_per_batch=points_per_batch,
796800
point_grids=point_grids,
801+
stability_score_offset=stability_score_offset,
797802
)
798803

799804
@torch.no_grad()
@@ -821,20 +826,24 @@ def initialize(
821826
embedding_save_path: Where to save the image embeddings.
822827
"""
823828
original_size = image.shape[:2]
829+
self._original_size = original_size
830+
824831
image_embeddings, tile_shape, halo = _compute_tiled_embeddings(
825832
self._predictor, image, image_embeddings, embedding_save_path, tile_shape, halo
826833
)
827834

828835
tiling = blocking([0, 0], original_size, tile_shape)
829836
n_tiles = tiling.numberOfBlocks
830837

838+
# the crop box is always the full local tile
839+
tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
840+
crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
841+
842+
# we need to cast to the image representation that is compatible with SAM
843+
image = util._to_image(image)
844+
831845
mask_data = []
832846
for tile_id in tqdm(range(n_tiles), total=n_tiles, desc="Compute masks for tile", disable=not verbose):
833-
# get the bounding box for this tile and crop the image data
834-
tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
835-
tile_bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
836-
tile_data = image[tile_bb]
837-
838847
# set the pre-computed embeddings for this tile
839848
features = image_embeddings["features"][tile_id]
840849
tile_embeddings = {
@@ -846,18 +855,14 @@ def initialize(
846855

847856
# compute the mask data for this tile and append it
848857
this_mask_data = self._process_crop(
849-
tile_data, crop_box=None, crop_layer_idx=0, verbose=verbose, precomputed_embeddings=True
858+
image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, verbose=verbose, precomputed_embeddings=True
850859
)
851860
mask_data.append(this_mask_data)
852861

853862
# set the initialized data
854863
self._is_initialized = True
855864
self._crop_list = mask_data
856-
self._original_size = original_size
857-
858-
# the crop box is always the full local tile
859-
tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
860-
self._crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
865+
self._crop_boxes = crop_boxes
861866

862867

863868
class TiledEmbeddingMaskGenerator(EmbeddingMaskGenerator):
@@ -924,7 +929,10 @@ def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n
924929
"original_size": this_tile_shape
925930
}
926931
util.set_precomputed(self._predictor, tile_image_embeddings, i)
927-
tile_data = self._compute_mask_data(initial_segmentations[tile_id], this_tile_shape, verbose=False)
932+
this_crop_box = [0, 0, this_tile_shape[1], this_tile_shape[0]]
933+
tile_data = self._compute_mask_data(
934+
initial_segmentations[tile_id], this_crop_box, this_tile_shape, verbose=False
935+
)
928936
mask_data.append(tile_data)
929937

930938
return mask_data
@@ -982,7 +990,6 @@ def generate(
982990
self,
983991
pred_iou_thresh: float = 0.88,
984992
stability_score_thresh: float = 0.95,
985-
stability_score_offset: float = 1.0,
986993
box_nms_thresh: float = 0.7,
987994
min_mask_region_area: int = 0,
988995
verbose: bool = False
@@ -993,7 +1000,6 @@ def generate(
9931000
pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
9941001
stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
9951002
under changes to the cutoff used to binarize the model prediction.
996-
stability_score_offset: The amount to shift the cutoff when calculating the stability score.
9971003
box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
9981004
min_mask_region_area: Minimal size for the predicted masks.
9991005
verbose: Whether to print progress of the computation.
@@ -1014,7 +1020,6 @@ def segment_tile(_, tile_id):
10141020
data=mask_data, crop_box=crop_box, original_size=this_tile_shape,
10151021
pred_iou_thresh=pred_iou_thresh,
10161022
stability_score_thresh=stability_score_thresh,
1017-
stability_score_offset=stability_score_offset,
10181023
box_nms_thresh=box_nms_thresh,
10191024
)
10201025
mask_data.to_numpy()

0 commit comments

Comments
 (0)