11# TODO: Test notebook with Colab
22# TODO: Add "Open in Colab" badge to notebook (see SAM)
33# TODO: Test mixed precision behaviour
4-
5- # TODO: Compile documentation
64# TODO: Convert PT models to TF and upload to GitHub
75import math
86from typing import Callable , Optional , Tuple
1715
1816class SAMPredictor :
1917 """
20- Uses SAM to calculate the image embedding for an image, and then allows
21- repeated, efficient mask prediction given prompts.
18+ User-friendly interface to the Segment Anything model. Uses SAM to calculate the
19+ image embedding for an image, and then allows repeated, efficient mask prediction
20+ given prompts.
2221
2322 While internally TF is used for inference, the inputs and return values in this
2423 class are numpy arrays for ease of use.
@@ -47,15 +46,21 @@ def __init__(
4746
4847 def set_image (self , image : np .ndarray ):
4948 """
50- Calculates the image embeddings for the provided image, allowing masks to be
51- predicted much faster.
49+ Calculates and stores the image embeddings for the provided image, allowing
50+ masks to be predicted much faster.
5251
5352 Args:
54- image: An array of shape (H, W, C) with pixel values in [0, 255].
53+ image: An array of shape (H, W, C) with pixel values in [0, 255]. The image
54+ can be any shape, and it will be resized and padded to the model input
55+ shape as necessary.
56+
57+ Returns:
58+ Nothing. The image embedding and resizing information are stored in the
59+ class.
5560 """
5661 if self .model .cfg .fixed_input_size :
5762 self .resizer = ImageResizer (
58- src_size = image .shape [:2 ], dst_size = self .self . model .cfg .input_size
63+ src_size = image .shape [:2 ], dst_size = self .model .cfg .input_size
5964 )
6065 else :
6166 # If the model allows flexible input sizes, we simply pad the image to
@@ -87,6 +92,22 @@ def clear_image(self):
8792 self .image_embedding = None
8893 self .image_set = False
8994
95+ def input_size (self ):
96+ """Returns the input size to the model."""
97+ if self .image_set :
98+ return self .resizer .dst_size
99+ elif self .model .cfg .fixed_input_size :
100+ return self .model .cfg .input_size
101+ else :
102+ raise ValueError (
103+ "To determine model input size need to set image or use a model with "
104+ "a fixed input size."
105+ )
106+
107+ def mask_size (self ):
108+ """Returns the mask prompt input size to the model."""
109+ return self .model .mask_size (self .input_size ())
110+
90111 def preprocess_masks (self , mask : np .ndarray ) -> np .ndarray :
91112 """
92113 Preprocesses a mask from the pixel space of the original image (H0, W0), to the
@@ -106,7 +127,7 @@ def preprocess_masks(self, mask: np.ndarray) -> np.ndarray:
106127 mask = self .resizer .pad_image (mask , channels_last = False )
107128
108129 # Then we rescale to mask_size
109- mask_size = ( self .resizer . dst_size [ 0 ] // 4 , self . resizer . dst_size [ 1 ] // 4 )
130+ mask_size = self .mask_size ( )
110131 mask = self .resizer .scale_to_size (mask , size = mask_size , channels_last = False )
111132 return mask
112133
@@ -124,10 +145,10 @@ def __call__(
124145 already been set.
125146
126147 The original image size is (H0, W0). After resizing and padding the image size
127- becomes (H, W) as given by `input_size` (usually (1024, 1024)). Mask input and
128- logit output will have shape (H', W') given by `mask_size` (usually H'=H/4).
148+ becomes (H, W) given by ``input_size`` (usually (1024, 1024)). Mask input and
149+ logit output will have shape (H', W') given by ``mask_size`` (usually H'=H/4).
129150
130- One can use `preprocess_masks` to transform an input mask from (H0, W0) to
151+ One can use ``preprocess_masks`` to transform an input mask from (H0, W0) to
131152 (H', W').
132153
133154 Prompts can also be batched, i.e., have the shape (N, M1, 2) for points;
@@ -149,18 +170,23 @@ def __call__(
149170 return_logits: If True, we don't threshold the upscaled mask.
150171
151172 Returns:
152- masks: A (K, H, W) bool tensor of binary masked predictions, where K is
153- determined by the multimask_output parameter. It is either 1, if
154- ``multimask_output=False`` or given by the ``nb_multimask_outputs``
155- parameter in the model configuration.
156- scores: An (K,) array with the model's predictions of mask quality.
157- logits: An (K, H', W') array with low resoulution logits, where usually
158- H'=H/4 and W'=W/4. This can be passed as mask input to subsequent
159- iterations of prediction.
173+ * Masks, a (K, H, W) bool array of binary mask predictions, where K is
174+ determined by the multimask_output parameter. It is either 1, if
175+ ``multimask_output=False`` or given by the ``nb_multimask_outputs``
176+ parameter in the model configuration.
177+ * Scores, a (K,) array with the model's predictions of mask quality.
178+ * Logits, a (K, H', W') array with low resolution logits, where usually
179+ H'=H/4 and W'=W/4. This can be passed as mask input to subsequent
180+ iterations of prediction.
160181 """
161182 if not self .image_set :
162183 raise ValueError ("Need to set image before calling predict()." )
163184
185+ points = np .asarray (points ) if points is not None else None
186+ labels = np .asarray (labels ) if labels is not None else None
187+ boxes = np .asarray (boxes ) if boxes is not None else None
188+ masks = np .asarray (masks ) if masks is not None else None
189+
164190 batch_shape = self ._batch_shape (points , labels , boxes , masks )
165191
166192 if points is None :
@@ -170,7 +196,7 @@ def __call__(
170196 if boxes is None :
171197 boxes = np .zeros (batch_shape + (0 , 4 ), dtype = np .float32 )
172198 if masks is None :
173- mask_size = ( self .resizer . dst_size [ 0 ] // 4 , self . resizer . dst_size [ 1 ] // 4 )
199+ mask_size = self .mask_size ( )
174200 masks = np .zeros (batch_shape + (0 , * mask_size ), dtype = np .float32 )
175201
176202 # Check that batch shapes are compatible
@@ -240,7 +266,9 @@ def _predict_tf(self, points, labels, boxes, masks, multimask_output):
240266 multimask_output = multimask_output ,
241267 )
242268
243- masks = self .model ._postprocess_logits (logits , return_logits = True )
269+ masks = self .model .postprocess_logits (
270+ logits , input_size = self .input_size (), return_logits = True
271+ )
244272 return masks , scores , logits
245273
246274 @staticmethod
@@ -263,6 +291,14 @@ class ImageResizer:
263291 Utility class to resize images to the largest side that fits in a given shape while
264292 preserving the aspect ratio. It also provides methods to resize coordinates and
265293 bounding boxes and pad images.
294+
295+ Args:
296+ src_size: Size of image before resizing. The resize object is image
297+ specific, i.e., for each source image size it is recommended to create
298+ a new ``ImageResizer`` object.
299+ dst_size: The target size after resizing (and padding).
300+ pad_only: If True, we don't do any resizing and only pad the image to
301+ ``dst_size``.
266302 """
267303
268304 def __init__ (
@@ -271,24 +307,13 @@ def __init__(
271307 dst_size : Tuple [int , int ],
272308 pad_only : bool = False ,
273309 ):
274- """
275- Creates an ``ImageResizer`` object.
276-
277- Args:
278- src_size: Size of image before resizing. The resize object is image
279- specific, i.e., for each source image size it is recommended to create
280- a new ``ImageResizer`` object.
281- dst_size: The target size after resizing (and padding).
282- pad_only: If True, we don't do any resizing and only pad the image to
283- ``dst_size``.
284- """
285310 self .src_size = src_size
286311 self .dst_size = dst_size
287312 self .pad_only = pad_only
288313
289314 self .scale , self .rescaled_size = self ._get_scale ()
290315
291- def _get_scale (self ):
316+ def _get_scale (self ) -> Tuple [ float , Tuple [ int , int ]] :
292317 """Calculate rescaling parameters."""
293318 if self .pad_only :
294319 # If we only pad, then scale is 1 and the rescaled size equal input size.
@@ -421,7 +446,7 @@ def pad_image(self, image: np.ndarray, channels_last: bool = True) -> np.ndarray
421446
422447 return image
423448
424- def scale_points (self , points ) :
449+ def scale_points (self , points : np . ndarray ) -> np . ndarray :
425450 """
426451 Scale points by the same factor as the image.
427452
@@ -433,7 +458,7 @@ def scale_points(self, points):
433458 """
434459 return self .scale * points
435460
436- def scale_boxes (self , boxes ) :
461+ def scale_boxes (self , boxes : np . ndarray ) -> np . ndarray :
437462 """
438463 Scale bounding boxes by the same factor as the image.
439464
@@ -445,7 +470,9 @@ def scale_boxes(self, boxes):
445470 """
446471 return self .scale * boxes
447472
448- def postprocess_mask (self , mask , threshold : Optional [float ] = None ):
473+ def postprocess_mask (
474+ self , mask : np .ndarray , threshold : Optional [float ] = None
475+ ) -> np .ndarray :
449476 """
450477 Convert an upscaled segmentation mask from ``dst_size`` back to ``src_size``
451478 by removing padding and unscaling.
0 commit comments