@@ -135,7 +135,6 @@ def _postprocess_batch(
135135 original_size ,
136136 pred_iou_thresh ,
137137 stability_score_thresh ,
138- stability_score_offset ,
139138 box_nms_thresh ,
140139 ):
141140 orig_h , orig_w = original_size
@@ -145,28 +144,16 @@ def _postprocess_batch(
145144 keep_mask = data ["iou_preds" ] > pred_iou_thresh
146145 data .filter (keep_mask )
147146
148- # calculate stability score
149- data ["stability_score" ] = amg_utils .calculate_stability_score (
150- data ["masks" ], self ._predictor .model .mask_threshold , stability_score_offset
151- )
147+ # filter by stability score
152148 if stability_score_thresh > 0.0 :
153149 keep_mask = data ["stability_score" ] >= stability_score_thresh
154150 data .filter (keep_mask )
155151
156- # threshold masks and calculate boxes
157- data ["masks" ] = data ["masks" ] > self ._predictor .model .mask_threshold
158- data ["boxes" ] = amg_utils .batched_mask_to_box (data ["masks" ])
159-
160152 # filter boxes that touch crop boundaries
161153 keep_mask = ~ amg_utils .is_box_near_crop_edge (data ["boxes" ], crop_box , [0 , 0 , orig_w , orig_h ])
162154 if not torch .all (keep_mask ):
163155 data .filter (keep_mask )
164156
165- # compress to RLE
166- data ["masks" ] = amg_utils .uncrop_masks (data ["masks" ], crop_box , orig_h , orig_w )
167- data ["rles" ] = amg_utils .mask_to_rle_pytorch (data ["masks" ])
168- del data ["masks" ]
169-
170157 # remove duplicates within this crop.
171158 keep_by_nms = batched_nms (
172159 data ["boxes" ].float (),
@@ -267,6 +254,32 @@ def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, cr
267254
268255 return curr_anns
269256
257+ def _to_mask_data (self , masks , iou_preds , crop_box , original_size , points = None ):
258+ orig_h , orig_w = original_size
259+
260+ # flatten the first two axes of the predictions and store them in MaskData; points (if given) are repeated to match
261+ data = amg_utils .MaskData (masks = masks .flatten (0 , 1 ), iou_preds = iou_preds .flatten (0 , 1 ))
262+ if points is not None :
263+ data ["points" ] = torch .as_tensor (points .repeat (masks .shape [1 ], axis = 0 ))
264+
265+ del masks # only drops the local name; the flattened masks stay available via data["masks"]
266+
267+ # calculate the stability scores on the masks before they are thresholded, using the offset stored on the instance
268+ data ["stability_score" ] = amg_utils .calculate_stability_score (
269+ data ["masks" ], self ._predictor .model .mask_threshold , self ._stability_score_offset
270+ )
271+
272+ # binarize the masks at the model's mask threshold and calculate boxes from the binarized masks
273+ data ["masks" ] = data ["masks" ] > self ._predictor .model .mask_threshold
274+ data ["boxes" ] = amg_utils .batched_mask_to_box (data ["masks" ])
275+
276+ # uncrop the masks using crop_box and the original image size, compress them to RLE and drop the dense masks
277+ data ["masks" ] = amg_utils .uncrop_masks (data ["masks" ], crop_box , orig_h , orig_w )
278+ data ["rles" ] = amg_utils .mask_to_rle_pytorch (data ["masks" ])
279+ del data ["masks" ]
280+
281+ return data
282+
270283 def get_state (self ) -> Dict [str , Any ]:
271284 """Get the initialized state of the mask generator.
272285
@@ -315,6 +328,7 @@ class AutomaticMaskGenerator(AMGBase):
315328 crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
316329 point_grids: A list of explicit grids of points used for sampling masks.
317330 Normalized to [0, 1] with respect to the image coordinate system.
331+ stability_score_offset: The amount to shift the cutoff when calculating the stability score.
318332 """
319333 def __init__ (
320334 self ,
@@ -325,6 +339,7 @@ def __init__(
325339 crop_overlap_ratio : float = 512 / 1500 ,
326340 crop_n_points_downscale_factor : int = 1 ,
327341 point_grids : Optional [List [np .ndarray ]] = None ,
342+ stability_score_offset : float = 1.0 ,
328343 ):
329344 super ().__init__ ()
330345
@@ -345,8 +360,9 @@ def __init__(
345360 self ._crop_n_layers = crop_n_layers
346361 self ._crop_overlap_ratio = crop_overlap_ratio
347362 self ._crop_n_points_downscale_factor = crop_n_points_downscale_factor
363+ self ._stability_score_offset = stability_score_offset
348364
349- def _process_batch (self , points , im_size ):
365+ def _process_batch (self , points , im_size , crop_box , original_size ):
350366 # run model on this batch
351367 transformed_points = self ._predictor .transform .apply_coords (points , im_size )
352368 in_points = torch .as_tensor (transformed_points , device = self ._predictor .device )
@@ -357,24 +373,14 @@ def _process_batch(self, points, im_size):
357373 multimask_output = True ,
358374 return_logits = True ,
359375 )
360-
361- # serialize predictions and store in MaskData
362- data = amg_utils .MaskData (
363- masks = masks .flatten (0 , 1 ),
364- iou_preds = iou_preds .flatten (0 , 1 ),
365- points = torch .as_tensor (points .repeat (masks .shape [1 ], axis = 0 )),
366- )
376+ data = self ._to_mask_data (masks , iou_preds , crop_box , original_size , points = points )
367377 del masks
368-
369378 return data
370379
371380 def _process_crop (self , image , crop_box , crop_layer_idx , verbose , precomputed_embeddings ):
372381 # crop the image and calculate embeddings
373- if crop_box is None :
374- cropped_im = image
375- else :
376- x0 , y0 , x1 , y1 = crop_box
377- cropped_im = image [y0 :y1 , x0 :x1 , :]
382+ x0 , y0 , x1 , y1 = crop_box
383+ cropped_im = image [y0 :y1 , x0 :x1 , :]
378384 cropped_im_size = cropped_im .shape [:2 ]
379385
380386 if not precomputed_embeddings :
@@ -393,7 +399,7 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
393399 disable = not verbose , total = n_batches ,
394400 desc = "Predict masks for point grid prompts" ,
395401 ):
396- batch_data = self ._process_batch (points , cropped_im_size )
402+ batch_data = self ._process_batch (points , cropped_im_size , crop_box , self . original_size )
397403 data .cat (batch_data )
398404 del batch_data
399405
@@ -421,6 +427,8 @@ def initialize(
421427 verbose: Whether to print computation progress.
422428 """
423429 original_size = image .shape [:2 ]
430+ self ._original_size = original_size
431+
424432 crop_boxes , layer_idxs = amg_utils .generate_crop_boxes (
425433 original_size , self ._crop_n_layers , self ._crop_overlap_ratio
426434 )
@@ -449,14 +457,12 @@ def initialize(
449457 self ._is_initialized = True
450458 self ._crop_list = crop_list
451459 self ._crop_boxes = crop_boxes
452- self ._original_size = original_size
453460
454461 @torch .no_grad ()
455462 def generate (
456463 self ,
457464 pred_iou_thresh : float = 0.88 ,
458465 stability_score_thresh : float = 0.95 ,
459- stability_score_offset : float = 1.0 ,
460466 box_nms_thresh : float = 0.7 ,
461467 crop_nms_thresh : float = 0.7 ,
462468 min_mask_region_area : int = 0 ,
@@ -468,7 +474,6 @@ def generate(
468474 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
469475 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
470476 under changes to the cutoff used to binarize the model prediction.
471- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
472477 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
473478 crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
474479 min_mask_region_area: Minimal size for the predicted masks.
@@ -487,7 +492,6 @@ def generate(
487492 crop_box = crop_box , original_size = self .original_size ,
488493 pred_iou_thresh = pred_iou_thresh ,
489494 stability_score_thresh = stability_score_thresh ,
490- stability_score_offset = stability_score_offset ,
491495 box_nms_thresh = box_nms_thresh
492496 )
493497 data .cat (crop_data )
@@ -535,6 +539,7 @@ class EmbeddingMaskGenerator(AMGBase):
535539 use_mask: Whether to use the initial segments as prompts.
536540 use_points: Whether to use points derived from the initial segments as prompts.
537541 box_extension: Factor for extending the bounding box prompts, given in the relative box size.
542+ stability_score_offset: The amount to shift the cutoff when calculating the stability score.
538543 """
539544 default_offsets = [[- 1 , 0 ], [0 , - 1 ], [- 3 , 0 ], [0 , - 3 ], [- 9 , 0 ], [0 , - 9 ]]
540545
@@ -549,6 +554,7 @@ def __init__(
549554 use_mask : bool = True ,
550555 use_points : bool = False ,
551556 box_extension : float = 0.05 ,
557+ stability_score_offset : float = 1.0 ,
552558 ):
553559 super ().__init__ ()
554560
@@ -561,6 +567,7 @@ def __init__(
561567 self ._use_mask = use_mask
562568 self ._use_points = use_points
563569 self ._box_extension = box_extension
570+ self ._stability_score_offset = stability_score_offset
564571
565572 # additional state that is set in 'initialize'
566573 self ._initial_segmentation = None
@@ -587,7 +594,7 @@ def _compute_initial_segmentation(self):
587594
588595 return initial_segmentation
589596
590- def _compute_mask_data (self , initial_segmentation , original_size , verbose ):
597+ def _compute_mask_data (self , initial_segmentation , crop_box , original_size , verbose ):
591598 seg_ids = np .unique (initial_segmentation )
592599 if seg_ids [0 ] == 0 :
593600 seg_ids = seg_ids [1 :]
@@ -602,11 +609,9 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
602609 use_box = self ._use_box , use_mask = self ._use_mask , use_points = self ._use_points ,
603610 box_extension = self ._box_extension ,
604611 )
605- data = amg_utils .MaskData (
606- masks = torch .from_numpy (masks ),
607- iou_preds = torch .from_numpy (iou_preds ),
608- seg_id = torch .from_numpy (np .full (len (masks ), seg_id , dtype = "int64" )),
609- )
612+ # bring masks and iou_preds to a format compatible with _to_mask_data
613+ masks , iou_preds = torch .from_numpy (masks [None ]), torch .from_numpy (iou_preds [None ])
614+ data = self ._to_mask_data (masks , iou_preds , crop_box , original_size )
610615 del masks
611616 mask_data .cat (data )
612617
@@ -631,6 +636,11 @@ def initialize(
631636 verbose: Whether to print computation progress.
632637 """
633638 original_size = image .shape [:2 ]
639+ self ._original_size = original_size
640+
641+ # the crop box is always the full image
642+ crop_box = [0 , 0 , original_size [1 ], original_size [0 ]]
643+ self ._crop_boxes = [crop_box ]
634644
635645 if image_embeddings is None :
636646 image_embeddings = util .precompute_image_embeddings (self ._predictor , image ,)
@@ -639,26 +649,20 @@ def initialize(
639649 # compute the initial segmentation via embedding based MWS and then refine the masks
640650 # with the segment anything model
641651 initial_segmentation = self ._compute_initial_segmentation ()
642- mask_data = self ._compute_mask_data (initial_segmentation , original_size , verbose )
652+ mask_data = self ._compute_mask_data (initial_segmentation , crop_box , original_size , verbose )
643653 # to be compatible with the file format of the super class we have to wrap the mask data in a list
644654 crop_list = [mask_data ]
645655
646656 # set the initialized data
647657 self ._is_initialized = True
648658 self ._initial_segmentation = initial_segmentation
649659 self ._crop_list = crop_list
650- # the crop box is always the full image
651- self ._crop_boxes = [
652- [0 , 0 , original_size [1 ], original_size [0 ]]
653- ]
654- self ._original_size = original_size
655660
656661 @torch .no_grad ()
657662 def generate (
658663 self ,
659664 pred_iou_thresh : float = 0.88 ,
660665 stability_score_thresh : float = 0.95 ,
661- stability_score_offset : float = 1.0 ,
662666 box_nms_thresh : float = 0.7 ,
663667 min_mask_region_area : int = 0 ,
664668 output_mode : str = "binary_mask" ,
@@ -669,7 +673,6 @@ def generate(
669673 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
670674 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
671675 under changes to the cutoff used to binarize the model prediction.
672- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
673676 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
674677 min_mask_region_area: Minimal size for the predicted masks.
675678 output_mode: The form masks are returned in.
@@ -685,7 +688,6 @@ def generate(
685688 original_size = self .original_size ,
686689 pred_iou_thresh = pred_iou_thresh ,
687690 stability_score_thresh = stability_score_thresh ,
688- stability_score_offset = stability_score_offset ,
689691 box_nms_thresh = box_nms_thresh
690692 )
691693
@@ -777,6 +779,7 @@ class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
777779 Higher numbers may be faster but use more GPU memory.
778780 point_grids: A list of explicit grids of points used for sampling masks.
779781 Normalized to [0, 1] with respect to the image coordinate system.
782+ stability_score_offset: The amount to shift the cutoff when calculating the stability score.
780783 """
781784
782785 # We only expose the arguments that make sense for the tiled mask generator.
@@ -788,12 +791,14 @@ def __init__(
788791 points_per_side : Optional [int ] = 32 ,
789792 points_per_batch : int = 64 ,
790793 point_grids : Optional [List [np .ndarray ]] = None ,
794+ stability_score_offset : float = 1.0 ,
791795 ) -> None :
792796 super ().__init__ (
793797 predictor = predictor ,
794798 points_per_side = points_per_side ,
795799 points_per_batch = points_per_batch ,
796800 point_grids = point_grids ,
801+ stability_score_offset = stability_score_offset ,
797802 )
798803
799804 @torch .no_grad ()
@@ -821,20 +826,24 @@ def initialize(
821826 embedding_save_path: Where to save the image embeddings.
822827 """
823828 original_size = image .shape [:2 ]
829+ self ._original_size = original_size
830+
824831 image_embeddings , tile_shape , halo = _compute_tiled_embeddings (
825832 self ._predictor , image , image_embeddings , embedding_save_path , tile_shape , halo
826833 )
827834
828835 tiling = blocking ([0 , 0 ], original_size , tile_shape )
829836 n_tiles = tiling .numberOfBlocks
830837
838+ # the crop box is always the full local tile
839+ tiles = [tiling .getBlockWithHalo (tile_id , list (halo )).outerBlock for tile_id in range (n_tiles )]
840+ crop_boxes = [[tile .begin [1 ], tile .begin [0 ], tile .end [1 ], tile .end [0 ]] for tile in tiles ]
841+
842+ # we need to cast to the image representation that is compatible with SAM
843+ image = util ._to_image (image )
844+
831845 mask_data = []
832846 for tile_id in tqdm (range (n_tiles ), total = n_tiles , desc = "Compute masks for tile" , disable = not verbose ):
833- # get the bounding box for this tile and crop the image data
834- tile = tiling .getBlockWithHalo (tile_id , list (halo )).outerBlock
835- tile_bb = tuple (slice (beg , end ) for beg , end in zip (tile .begin , tile .end ))
836- tile_data = image [tile_bb ]
837-
838847 # set the pre-computed embeddings for this tile
839848 features = image_embeddings ["features" ][tile_id ]
840849 tile_embeddings = {
@@ -846,18 +855,14 @@ def initialize(
846855
847856 # compute the mask data for this tile and append it
848857 this_mask_data = self ._process_crop (
849- tile_data , crop_box = None , crop_layer_idx = 0 , verbose = verbose , precomputed_embeddings = True
858+ image , crop_box = crop_boxes [ tile_id ] , crop_layer_idx = 0 , verbose = verbose , precomputed_embeddings = True
850859 )
851860 mask_data .append (this_mask_data )
852861
853862 # set the initialized data
854863 self ._is_initialized = True
855864 self ._crop_list = mask_data
856- self ._original_size = original_size
857-
858- # the crop box is always the full local tile
859- tiles = [tiling .getBlockWithHalo (tile_id , list (halo )).outerBlock for tile_id in range (n_tiles )]
860- self ._crop_boxes = [[tile .begin [1 ], tile .begin [0 ], tile .end [1 ], tile .end [0 ]] for tile in tiles ]
865+ self ._crop_boxes = crop_boxes
861866
862867
863868class TiledEmbeddingMaskGenerator (EmbeddingMaskGenerator ):
@@ -924,7 +929,10 @@ def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n
924929 "original_size" : this_tile_shape
925930 }
926931 util .set_precomputed (self ._predictor , tile_image_embeddings , i )
927- tile_data = self ._compute_mask_data (initial_segmentations [tile_id ], this_tile_shape , verbose = False )
932+ this_crop_box = [0 , 0 , this_tile_shape [1 ], this_tile_shape [0 ]]
933+ tile_data = self ._compute_mask_data (
934+ initial_segmentations [tile_id ], this_crop_box , this_tile_shape , verbose = False
935+ )
928936 mask_data .append (tile_data )
929937
930938 return mask_data
@@ -982,7 +990,6 @@ def generate(
982990 self ,
983991 pred_iou_thresh : float = 0.88 ,
984992 stability_score_thresh : float = 0.95 ,
985- stability_score_offset : float = 1.0 ,
986993 box_nms_thresh : float = 0.7 ,
987994 min_mask_region_area : int = 0 ,
988995 verbose : bool = False
@@ -993,7 +1000,6 @@ def generate(
9931000 pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
9941001 stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
9951002 under changes to the cutoff used to binarize the model prediction.
996- stability_score_offset: The amount to shift the cutoff when calculating the stability score.
9971003 box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
9981004 min_mask_region_area: Minimal size for the predicted masks.
9991005 verbose: Whether to print progress of the computation.
@@ -1014,7 +1020,6 @@ def segment_tile(_, tile_id):
10141020 data = mask_data , crop_box = crop_box , original_size = this_tile_shape ,
10151021 pred_iou_thresh = pred_iou_thresh ,
10161022 stability_score_thresh = stability_score_thresh ,
1017- stability_score_offset = stability_score_offset ,
10181023 box_nms_thresh = box_nms_thresh ,
10191024 )
10201025 mask_data .to_numpy ()
0 commit comments