Skip to content

Commit 3814ccf

Browse files
Fixing bugs and adding docs to SAM. (#80)
Fixing bugs and adding docs and unit tests.
1 parent a8b5560 commit 3814ccf

File tree

4 files changed

+193
-73
lines changed

4 files changed

+193
-73
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Segment Anything
2+
================
3+
4+
.. py:module:: tfimm.architectures.segment_anything.sam
5+
6+
.. automodule:: tfimm.architectures.segment_anything.sam
7+
8+
.. autoclass:: SegmentAnythingModelConfig
9+
.. autoclass:: SegmentAnythingModel
10+
:members: grid_size, mask_size, mask_threshold, dummy_inputs, call
11+
12+
.. py:module:: tfimm.architectures.segment_anything.predictor
13+
14+
.. autoclass:: SAMPredictor
15+
:members: set_image, clear_image, preprocess_masks, __call__
16+
.. autoclass:: ImageResizer
17+
:members: scale_to_size, scale_image, unscale_image, pad_image, scale_points,
18+
scale_boxes, postprocess_mask

tests/models/test_segment_anything.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from tfimm.architectures.segment_anything import (
99
ImageResizer,
10+
SAMPredictor,
1011
SegmentAnythingModel,
1112
SegmentAnythingModelConfig,
1213
)
@@ -452,3 +453,18 @@ def test_transfer_weights():
452453
res_2 = model_2.image_encoder(img, training=False).numpy()
453454

454455
np.testing.assert_almost_equal(res_1, res_2, decimal=5)
456+
457+
458+
@pytest.mark.parametrize("fixed_input_size", [True, False])
459+
def test_predictor(fixed_input_size):
460+
sam = create_model("sam_vit_test_model", fixed_input_size=fixed_input_size)
461+
cast(SegmentAnythingModel, sam)
462+
predictor = SAMPredictor(model=sam)
463+
464+
# We use something different from input_size, because the predictor should deal
465+
# with resizing.
466+
img = np.random.rand(20, 20, 3)
467+
predictor.set_image(img)
468+
masks, scores, logits = predictor(points=[[10, 10]], multimask_output=False)
469+
470+
assert masks.shape == (1, *img.shape[:2])

tfimm/architectures/segment_anything/predictor.py

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# TODO: Test notebook with Colab
22
# TODO: Add "Open in Colab" badge to notebook (see SAM)
33
# TODO: Test mixed precision behaviour
4-
5-
# TODO: Compile documentation
64
# TODO: Convert PT models to TF and upload to GitHub
75
import math
86
from typing import Callable, Optional, Tuple
@@ -17,8 +15,9 @@
1715

1816
class SAMPredictor:
1917
"""
20-
Uses SAM to calculate the image embedding for an image, and then allows
21-
repeated, efficient mask prediction given prompts.
18+
User-friendly interface to the Segment Anything model. Uses SAM to calculate the
19+
image embedding for an image, and then allows repeated, efficient mask prediction
20+
given prompts.
2221
2322
While internally TF is used for inference, the inputs and return values in this
2423
class are numpy arrays for ease of use.
@@ -47,15 +46,21 @@ def __init__(
4746

4847
def set_image(self, image: np.ndarray):
4948
"""
50-
Calculates the image embeddings for the provided image, allowing masks to be
51-
predicted much faster.
49+
Calculates and stores the image embeddings for the provided image, allowing
50+
masks to be predicted much faster.
5251
5352
Args:
54-
image: An array of shape (H, W, C) with pixel values in [0, 255].
53+
image: An array of shape (H, W, C) with pixel values in [0, 255]. The image
54+
can be any shape, and it will be resized and padded to the model input
55+
shape as necessary.
56+
57+
Returns:
58+
Nothing. The image embedding and resizing information are stored in the
59+
class.
5560
"""
5661
if self.model.cfg.fixed_input_size:
5762
self.resizer = ImageResizer(
58-
src_size=image.shape[:2], dst_size=self.self.model.cfg.input_size
63+
src_size=image.shape[:2], dst_size=self.model.cfg.input_size
5964
)
6065
else:
6166
# If the model allows flexible input sizes, we simply pad the image to
@@ -87,6 +92,22 @@ def clear_image(self):
8792
self.image_embedding = None
8893
self.image_set = False
8994

95+
def input_size(self):
96+
"""Returns the input size to the model."""
97+
if self.image_set:
98+
return self.resizer.dst_size
99+
elif self.model.cfg.fixed_input_size:
100+
return self.model.cfg.input_size
101+
else:
102+
raise ValueError(
103+
"To determine model input size need to set image or use a model with "
104+
"a fixed input size."
105+
)
106+
107+
def mask_size(self):
108+
"""Returns the mask prompt input size to the model."""
109+
return self.model.mask_size(self.input_size())
110+
90111
def preprocess_masks(self, mask: np.ndarray) -> np.ndarray:
91112
"""
92113
Preprocesses a mask from the pixel space of the original image (H0, W0), to the
@@ -106,7 +127,7 @@ def preprocess_masks(self, mask: np.ndarray) -> np.ndarray:
106127
mask = self.resizer.pad_image(mask, channels_last=False)
107128

108129
# Then we rescale to mask_size
109-
mask_size = (self.resizer.dst_size[0] // 4, self.resizer.dst_size[1] // 4)
130+
mask_size = self.mask_size()
110131
mask = self.resizer.scale_to_size(mask, size=mask_size, channels_last=False)
111132
return mask
112133

@@ -124,10 +145,10 @@ def __call__(
124145
already been set.
125146
126147
The original image size is (H0, W0). After resizing and padding the image size
127-
becomes (H, W) as given by `input_size` (usually (1024, 1024)). Mask input and
128-
logit output will have shape (H', W') given by `mask_size` (usually H'=H/4).
148+
becomes (H, W) given by ``input_size`` (usually (1024, 1024)). Mask input and
149+
logit output will have shape (H', W') given by ``mask_size`` (usually H'=H/4).
129150
130-
One can use `preprocess_masks` to transform an input mask from (H0, W0) to
151+
One can use ``preprocess_masks`` to transform an input mask from (H0, W0) to
131152
(H', W').
132153
133154
Prompts can also be batched, i.e., have the shape (N, M1, 2) for points;
@@ -149,18 +170,23 @@ def __call__(
149170
return_logits: If True, we don't threshold the upscaled mask.
150171
151172
Returns:
152-
masks: A (K, H, W) bool tensor of binary masked predictions, where K is
153-
determined by the multimask_output parameter. It is either 1, if
154-
``multimask_output=False`` or given by the ``nb_multimask_outputs``
155-
parameter in the model configuration.
156-
scores: An (K,) array with the model's predictions of mask quality.
157-
logits: An (K, H', W') array with low resoulution logits, where usually
158-
H'=H/4 and W'=W/4. This can be passed as mask input to subsequent
159-
iterations of prediction.
173+
* Masks, a (K, H, W) bool array of binary mask predictions, where K is
174+
determined by the multimask_output parameter. It is either 1, if
175+
``multimask_output=False`` or given by the ``nb_multimask_outputs``
176+
parameter in the model configuration.
177+
* Scores, a (K,) array with the model's predictions of mask quality.
178+
* Logits, a (K, H', W') array with low-resolution logits, where usually
179+
H'=H/4 and W'=W/4. This can be passed as mask input to subsequent
180+
iterations of prediction.
160181
"""
161182
if not self.image_set:
162183
raise ValueError("Need to set image before calling predict().")
163184

185+
points = np.asarray(points) if points is not None else None
186+
labels = np.asarray(labels) if labels is not None else None
187+
boxes = np.asarray(boxes) if boxes is not None else None
188+
masks = np.asarray(masks) if masks is not None else None
189+
164190
batch_shape = self._batch_shape(points, labels, boxes, masks)
165191

166192
if points is None:
@@ -170,7 +196,7 @@ def __call__(
170196
if boxes is None:
171197
boxes = np.zeros(batch_shape + (0, 4), dtype=np.float32)
172198
if masks is None:
173-
mask_size = (self.resizer.dst_size[0] // 4, self.resizer.dst_size[1] // 4)
199+
mask_size = self.mask_size()
174200
masks = np.zeros(batch_shape + (0, *mask_size), dtype=np.float32)
175201

176202
# Check that batch shapes are compatible
@@ -240,7 +266,9 @@ def _predict_tf(self, points, labels, boxes, masks, multimask_output):
240266
multimask_output=multimask_output,
241267
)
242268

243-
masks = self.model._postprocess_logits(logits, return_logits=True)
269+
masks = self.model.postprocess_logits(
270+
logits, input_size=self.input_size(), return_logits=True
271+
)
244272
return masks, scores, logits
245273

246274
@staticmethod
@@ -263,6 +291,14 @@ class ImageResizer:
263291
Utility class to resize images to the largest side that fits in a given shape while
264292
preserving the aspect ratio. It also provides methods to resize coordinates and
265293
bounding boxes and pad images.
294+
295+
Args:
296+
src_size: Size of image before resizing. The resize object is image
297+
specific, i.e., for each source image size it is recommended to create
298+
a new ``ImageResizer`` object.
299+
dst_size: The target size after resizing (and padding).
300+
pad_only: If True, we don't do any resizing and only pad the image to
301+
``dst_size``.
266302
"""
267303

268304
def __init__(
@@ -271,24 +307,13 @@ def __init__(
271307
dst_size: Tuple[int, int],
272308
pad_only: bool = False,
273309
):
274-
"""
275-
Creates an ``ImageResizer`` object.
276-
277-
Args:
278-
src_size: Size of image before resizing. The resize object is image
279-
specific, i.e., for each source image size it is recommended to create
280-
a new ``ImageResizer`` object.
281-
dst_size: The target size after resizing (and padding).
282-
pad_only: If True, we don't do any resizing and only pad the image to
283-
``dst_size``.
284-
"""
285310
self.src_size = src_size
286311
self.dst_size = dst_size
287312
self.pad_only = pad_only
288313

289314
self.scale, self.rescaled_size = self._get_scale()
290315

291-
def _get_scale(self):
316+
def _get_scale(self) -> Tuple[float, Tuple[int, int]]:
292317
"""Calculate rescaling parameters."""
293318
if self.pad_only:
294319
# If we only pad, then scale is 1 and the rescaled size equal input size.
@@ -421,7 +446,7 @@ def pad_image(self, image: np.ndarray, channels_last: bool = True) -> np.ndarray
421446

422447
return image
423448

424-
def scale_points(self, points):
449+
def scale_points(self, points: np.ndarray) -> np.ndarray:
425450
"""
426451
Scale points by the same factor as the image.
427452
@@ -433,7 +458,7 @@ def scale_points(self, points):
433458
"""
434459
return self.scale * points
435460

436-
def scale_boxes(self, boxes):
461+
def scale_boxes(self, boxes: np.ndarray) -> np.ndarray:
437462
"""
438463
Scale bounding boxes by the same factor as the image.
439464
@@ -445,7 +470,9 @@ def scale_boxes(self, boxes):
445470
"""
446471
return self.scale * boxes
447472

448-
def postprocess_mask(self, mask, threshold: Optional[float] = None):
473+
def postprocess_mask(
474+
self, mask: np.ndarray, threshold: Optional[float] = None
475+
) -> np.ndarray:
449476
"""
450477
Convert an upscaled segmentation mask from ``dst_size`` back to ``src_size``
451478
by removing padding and unscaling.

0 commit comments

Comments
 (0)