Skip to content

Commit f2b20e0

Browse files
Add doc strings and type annotations for prompt_generators
1 parent 3aac00d commit f2b20e0

File tree

3 files changed

+92
-30
lines changed

3 files changed

+92
-30
lines changed

micro_sam/instance_segmentation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import multiprocessing as mp
22
import warnings
3-
from abc import ABC, Mapping
3+
from abc import ABC
4+
from collections.abc import Mapping
45
from concurrent import futures
56
from copy import deepcopy
67
from typing import Any, List, Optional
@@ -1055,8 +1056,6 @@ def segment_tile(_, tile_id):
10551056
# this is still experimental and not yet ready to be integrated within the annotator_3d
10561057
# (will need to see how well it works with retrained models)
10571058
def _segment_instances_from_embeddings_3d(predictor, image_embeddings, verbose=1, iou_threshold=0.50, **kwargs):
1058-
"""
1059-
"""
10601059
if image_embeddings["original_size"] is None: # tiled embeddings
10611060
is_tiled = True
10621061
image_shape = tuple(image_embeddings["features"].attrs["shape"])

micro_sam/prompt_generators.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,76 @@
1+
from collections.abc import Mapping
2+
from typing import Optional
3+
14
import numpy as np
25
from scipy.ndimage import binary_dilation
36

47

58
class PointAndBoxPromptGenerator:
6-
def __init__(self, n_positive_points, n_negative_points, dilation_strength,
7-
get_point_prompts=True, get_box_prompts=False):
9+
"""Generate point and/or box prompts from an instance segmentation.
10+
11+
You can use this class to derive prompts from an instance segmentation, either for
12+
evaluation purposes or for training Segment Anything on custom data.
13+
In order to use this generator you need to precompute the bounding boxes and center
14+
coordinates of the instance segmentation, using e.g. `util.get_bounding_boxes_and_centers`.
15+
Here's an example of how to use this class:
16+
```python
17+
# Initialize generator for 1 positive and 4 negative point prompts.
18+
prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)
19+
# Precompute the bounding boxes for the given segmentation
20+
bounding_boxes, _ = util.get_bounding_boxes_and_centers(segmentation)
21+
# generate point prompts for the object with id 1 in 'segmentation'
22+
seg_id = 1
23+
points, point_labels, _, _ = prompt_generator(segmentation, seg_id, bounding_boxes)
24+
```
25+
26+
Args:
27+
n_positive_points: The number of positive point prompts to generate per mask.
28+
n_negative_points: The number of negative point prompts to generate per mask.
29+
dilation_strength: The factor by which the mask is dilated before generating prompts.
30+
get_point_prompts: Whether to generate point prompts.
31+
get_box_prompts: Whether to generate box prompts.
32+
"""
33+
def __init__(
34+
self,
35+
n_positive_points: int,
36+
n_negative_points: int,
37+
dilation_strength: int,
38+
get_point_prompts: bool = True,
39+
get_box_prompts: bool = False
40+
):
841
self.n_positive_points = n_positive_points
942
self.n_negative_points = n_negative_points
1043
self.dilation_strength = dilation_strength
1144
self.get_box_prompts = get_box_prompts
1245
self.get_point_prompts = get_point_prompts
1346

1447
if self.get_point_prompts is False and self.get_box_prompts is False:
15-
raise ValueError("You need to request for box/point prompts or both")
16-
17-
def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
18-
"""
19-
Parameters:
20-
gt: True Labels
21-
gt_id: Instance ID for the Cells
22-
center_coordinates: Coordinates for the centroid seed of the cell
23-
bbox_coordinates: Bounding box coordinates around the cell
48+
raise ValueError("You need to request box prompts, point prompts or both.")
49+
50+
def __call__(
51+
self,
52+
segmentation: np.ndarray,
53+
segmentation_id: int,
54+
bbox_coordinates: Mapping[int, tuple],
55+
center_coordinates: Optional[Mapping[int, np.ndarray]] = None
56+
) -> tuple[
57+
Optional[list[tuple]], Optional[list[int]], Optional[list[tuple]], np.ndarray
58+
]:
59+
"""Generate the prompts for one object in the segmentation.
60+
61+
Args:
62+
segmentation: The instance segmentation.
63+
segmentation_id: The ID of the instance.
64+
bbox_coordinates: The precomputed bounding boxes of all objects in the segmentation.
65+
center_coordinates: The precomputed center coordinates of all objects in the segmentation.
66+
If passed, these coordinates will be used as the first positive point prompt.
67+
If not passed a random point from within the object mask will be used.
68+
69+
Returns:
70+
List of point coordinates. Returns None, if get_point_prompts is false.
71+
List of point labels. Returns None, if get_point_prompts is false.
72+
List containing the object bounding box. Returns None, if get_box_prompts is false.
73+
Object mask.
2474
"""
2575
coord_list = []
2676
label_list = []
@@ -30,7 +80,8 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
3080
coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float
3181
label_list.append(1)
3282

33-
# getting the additional positive points by randomly sampling points from this mask except the center coordinate
83+
# getting the additional positive points by randomly sampling points
84+
# from this mask except the center coordinate
3485
n_positive_remaining = self.n_positive_points - 1
3586

3687
else:
@@ -40,7 +91,7 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
4091
if self.get_box_prompts:
4192
bbox_list = [bbox_coordinates]
4293

43-
object_mask = gt == gt_id
94+
object_mask = segmentation == segmentation_id
4495

4596
if n_positive_remaining > 0:
4697
# all coordinates of our current object
@@ -50,9 +101,10 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
50101
n_coordinates = len(object_coordinates[0])
51102

52103
# randomly sampling n_positive_remaining_points from these coordinates
53-
positive_indices = np.random.choice(n_coordinates, replace=False,
54-
size=min(n_positive_remaining, n_coordinates) # handles the cases with insufficient fg pixels
55-
)
104+
positive_indices = np.random.choice(
105+
n_coordinates, replace=False,
106+
size=min(n_positive_remaining, n_coordinates) # handles the cases with insufficient fg pixels
107+
)
56108
for positive_index in positive_indices:
57109
positive_coordinates = int(object_coordinates[0][positive_index]), \
58110
int(object_coordinates[1][positive_index])
@@ -63,10 +115,10 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
63115
# getting the negative points
64116
# for this we do the opposite and we set the mask to the bounding box - the object mask
65117
# we need to dilate the object mask before doing this: we use scipy.ndimage.binary_dilation for this
66-
dilated_object = binary_dilation(object_mask, iterations=self.dilation_strength)
67-
background_mask = np.zeros(gt.shape)
118+
dilated_object = binary_dilation(object_mask, iterations=self.dilation_strength)
119+
background_mask = np.zeros(segmentation.shape)
68120
background_mask[bbox_coordinates[0]:bbox_coordinates[2], bbox_coordinates[1]:bbox_coordinates[3]] = 1
69-
background_mask = binary_dilation(background_mask, iterations=self.dilation_strength)
121+
background_mask = binary_dilation(background_mask, iterations=self.dilation_strength)
70122
background_mask = abs(
71123
background_mask.astype(np.float32) - dilated_object.astype(np.float32)
72124
) # casting booleans to do subtraction
@@ -80,9 +132,10 @@ def __call__(self, gt, gt_id, bbox_coordinates, center_coordinates=None):
80132
n_coordinates = len(background_coordinates[0])
81133

82134
# randomly sample n_positive_remaining_points from these coordinates
83-
negative_indices = np.random.choice(n_coordinates, replace=False,
84-
size=min(n_negative_remaining, n_coordinates) # handles the cases with insufficient bg pixels
85-
)
135+
negative_indices = np.random.choice(
136+
n_coordinates, replace=False,
137+
size=min(n_negative_remaining, n_coordinates) # handles the cases with insufficient bg pixels
138+
)
86139
for negative_index in negative_indices:
87140
negative_coordinates = int(background_coordinates[0][negative_index]), \
88141
int(background_coordinates[1][negative_index])

micro_sam/util.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,18 +430,28 @@ def compute_iou(mask1, mask2):
430430
return iou
431431

432432

433-
def get_cell_center_coordinates(gt, mode="v"):
434-
"""
435-
Returns the center coordinates of the foreground instances in the ground-truth
433+
def get_bounding_boxes_and_centers(
434+
segmentation: np.ndarray,
435+
mode: str = "v"
436+
) -> tuple[Mapping[int, np.ndarray], Mapping[int, tuple]]:
437+
"""Returns the center coordinates of the foreground instances in the ground-truth.
438+
439+
Args:
440+
segmentation: The instance segmentation.
441+
mode: Either 'p' (regionprops) or 'v' (vigra) for computing the centroids.
442+
443+
Returns:
444+
A dictionary that maps object ids to the corresponding centroid.
445+
A dictionary that maps object_ids to the corresponding bounding box.
436446
"""
437447
assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
438448

439-
properties = regionprops(gt)
449+
properties = regionprops(segmentation)
440450

441451
if mode == "p":
442452
center_coordinates = {prop.label: prop.centroid for prop in properties}
443453
elif mode == "v":
444-
center_coordinates = vigra.filters.eccentricityCenters(gt.astype('float32'))
454+
center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
445455
center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
446456

447457
bbox_coordinates = {prop.label: prop.bbox for prop in properties}

0 commit comments

Comments
 (0)