@@ -135,14 +135,14 @@ def _postprocess_batch(
135135
136136 # calculate stability score
137137 data ["stability_score" ] = amg_utils .calculate_stability_score (
138- data ["masks" ], self .predictor .model .mask_threshold , stability_score_offset
138+ data ["masks" ], self ._predictor .model .mask_threshold , stability_score_offset
139139 )
140140 if stability_score_thresh > 0.0 :
141141 keep_mask = data ["stability_score" ] >= stability_score_thresh
142142 data .filter (keep_mask )
143143
144144 # threshold masks and calculate boxes
145- data ["masks" ] = data ["masks" ] > self .predictor .model .mask_threshold
145+ data ["masks" ] = data ["masks" ] > self ._predictor .model .mask_threshold
146146 data ["boxes" ] = amg_utils .batched_mask_to_box (data ["masks" ])
147147
148148 # filter boxes that touch crop boundaries
@@ -327,19 +327,19 @@ def __init__(
327327 else :
328328 raise ValueError ("Can't have both points_per_side and point_grid be None or not None." )
329329
330- self .predictor = predictor
331- self .points_per_side = points_per_side
332- self .points_per_batch = points_per_batch
333- self .crop_n_layers = crop_n_layers
334- self .crop_overlap_ratio = crop_overlap_ratio
335- self .crop_n_points_downscale_factor = crop_n_points_downscale_factor
330+ self ._predictor = predictor
331+ self ._points_per_side = points_per_side
332+ self ._points_per_batch = points_per_batch
333+ self ._crop_n_layers = crop_n_layers
334+ self ._crop_overlap_ratio = crop_overlap_ratio
335+ self ._crop_n_points_downscale_factor = crop_n_points_downscale_factor
336336
337337 def _process_batch (self , points , im_size ):
338338 # run model on this batch
339- transformed_points = self .predictor .transform .apply_coords (points , im_size )
340- in_points = torch .as_tensor (transformed_points , device = self .predictor .device )
339+ transformed_points = self ._predictor .transform .apply_coords (points , im_size )
340+ in_points = torch .as_tensor (transformed_points , device = self ._predictor .device )
341341 in_labels = torch .ones (in_points .shape [0 ], dtype = torch .int , device = in_points .device )
342- masks , iou_preds , _ = self .predictor .predict_torch (
342+ masks , iou_preds , _ = self ._predictor .predict_torch (
343343 in_points [:, None , :],
344344 in_labels [:, None ],
345345 multimask_output = True ,
@@ -363,18 +363,18 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
363363 cropped_im_size = cropped_im .shape [:2 ]
364364
365365 if not precomputed_embeddings :
366- self .predictor .set_image (cropped_im )
366+ self ._predictor .set_image (cropped_im )
367367
368368 # get the points for this crop
369369 points_scale = np .array (cropped_im_size )[None , ::- 1 ]
370370 points_for_image = self .point_grids [crop_layer_idx ] * points_scale
371371
372372 # generate masks for this crop in batches
373373 data = amg_utils .MaskData ()
374- n_batches = len (points_for_image ) // self .points_per_batch + \
375- int (len (points_for_image ) % self .points_per_batch != 0 )
374+ n_batches = len (points_for_image ) // self ._points_per_batch + \
375+ int (len (points_for_image ) % self ._points_per_batch != 0 )
376376 for (points ,) in tqdm (
377- amg_utils .batch_iterator (self .points_per_batch , points_for_image ),
377+ amg_utils .batch_iterator (self ._points_per_batch , points_for_image ),
378378 disable = not verbose , total = n_batches ,
379379 desc = "Predict masks for point grid prompts" ,
380380 ):
@@ -383,7 +383,7 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
383383 del batch_data
384384
385385 if not precomputed_embeddings :
386- self .predictor .reset_image ()
386+ self ._predictor .reset_image ()
387387
388388 return data
389389
@@ -407,16 +407,16 @@ def initialize(
407407 """
408408 original_size = image .shape [:2 ]
409409 crop_boxes , layer_idxs = amg_utils .generate_crop_boxes (
410- original_size , self .crop_n_layers , self .crop_overlap_ratio
410+ original_size , self ._crop_n_layers , self ._crop_overlap_ratio
411411 )
412412
413413 # we can set fixed image embeddings if we only have a single crop box
414414 # (which is the default setting)
415415 # otherwise we have to recompute the embeddings for each crop and can't precompute
416416 if len (crop_boxes ) == 1 :
417417 if image_embeddings is None :
418- image_embeddings = util .precompute_image_embeddings (self .predictor , image )
419- util .set_precomputed (self .predictor , image_embeddings , i = i )
418+ image_embeddings = util .precompute_image_embeddings (self ._predictor , image )
419+ util .set_precomputed (self ._predictor , image_embeddings , i = i )
420420 precomputed_embeddings = True
421421 else :
422422 precomputed_embeddings = False
@@ -537,33 +537,33 @@ def __init__(
537537 ):
538538 super ().__init__ ()
539539
540- self .predictor = predictor
541- self .offsets = self .default_offsets if offsets is None else offsets
542- self .min_initial_size = min_initial_size
543- self .distance_type = distance_type
544- self .bias = bias
545- self .use_box = use_box
546- self .use_mask = use_mask
547- self .use_points = use_points
548- self .box_extension = box_extension
540+ self ._predictor = predictor
541+ self ._offsets = self .default_offsets if offsets is None else offsets
542+ self ._min_initial_size = min_initial_size
543+ self ._distance_type = distance_type
544+ self ._bias = bias
545+ self ._use_box = use_box
546+ self ._use_mask = use_mask
547+ self ._use_points = use_points
548+ self ._box_extension = box_extension
549549
550550 # additional state that is set 'initialize'
551551 self ._initial_segmentation = None
552552
553553 def _compute_initial_segmentation (self ):
554554
555- embeddings = self .predictor .get_image_embedding ().squeeze ().cpu ().numpy ()
555+ embeddings = self ._predictor .get_image_embedding ().squeeze ().cpu ().numpy ()
556556 assert embeddings .shape == (256 , 64 , 64 ), f"{ embeddings .shape } "
557557
558558 initial_segmentation = embed .segment_embeddings_mws (
559- embeddings , distance_type = self .distance_type , offsets = self .offsets , bias = self .bias ,
559+ embeddings , distance_type = self ._distance_type , offsets = self ._offsets , bias = self ._bias ,
560560 ).astype ("uint32" )
561561 assert initial_segmentation .shape == (64 , 64 ), f"{ initial_segmentation .shape } "
562562
563563 # filter out small initial objects
564- if self .min_initial_size > 0 :
564+ if self ._min_initial_size > 0 :
565565 seg_ids , sizes = np .unique (initial_segmentation , return_counts = True )
566- initial_segmentation [np .isin (initial_segmentation , seg_ids [sizes < self .min_initial_size ])] = 0
566+ initial_segmentation [np .isin (initial_segmentation , seg_ids [sizes < self ._min_initial_size ])] = 0
567567
568568 # resize to 256 x 256, which is the mask input expected by SAM
569569 initial_segmentation = resize (
@@ -582,10 +582,10 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
582582 for seg_id in tqdm (seg_ids , disable = not verbose , desc = "Compute masks from initial segmentation" ):
583583 mask = initial_segmentation == seg_id
584584 masks , iou_preds , _ = segment_from_mask (
585- self .predictor , mask , original_size = original_size ,
585+ self ._predictor , mask , original_size = original_size ,
586586 multimask_output = True , return_logits = True , return_all = True ,
587- use_box = self .use_box , use_mask = self .use_mask , use_points = self .use_points ,
588- box_extension = self .box_extension ,
587+ use_box = self ._use_box , use_mask = self ._use_mask , use_points = self ._use_points ,
588+ box_extension = self ._box_extension ,
589589 )
590590 data = amg_utils .MaskData (
591591 masks = torch .from_numpy (masks ),
@@ -618,8 +618,8 @@ def initialize(
618618 original_size = image .shape [:2 ]
619619
620620 if image_embeddings is None :
621- image_embeddings = util .precompute_image_embeddings (self .predictor , image ,)
622- util .set_precomputed (self .predictor , image_embeddings , i = i )
621+ image_embeddings = util .precompute_image_embeddings (self ._predictor , image ,)
622+ util .set_precomputed (self ._predictor , image_embeddings , i = i )
623623
624624 # compute the initial segmentation via embedding based MWS and then refine the masks
625625 # with the segment anything model
@@ -737,8 +737,8 @@ def __init__(
737737 ** kwargs
738738 ):
739739 super ().__init__ (predictor = predictor , ** kwargs )
740- self .n_threads = n_threads
741- self .with_background = with_background
740+ self ._n_threads = n_threads
741+ self ._with_background = with_background
742742
743743 # additional state for 'initialize'
744744 self ._tile_shape = None
@@ -758,10 +758,10 @@ def segment_tile(tile_id):
758758 "input_size" : tile_features .attrs ["input_size" ],
759759 "original_size" : tile_features .attrs ["original_size" ]
760760 }
761- util .set_precomputed (self .predictor , tile_image_embeddings , i )
761+ util .set_precomputed (self ._predictor , tile_image_embeddings , i )
762762 return self ._compute_initial_segmentation ()
763763
764- with futures .ThreadPoolExecutor (self .n_threads ) as tp :
764+ with futures .ThreadPoolExecutor (self ._n_threads ) as tp :
765765 initial_segmentations = list (tqdm (
766766 tp .map (segment_tile , range (n_tiles )), disable = not verbose , total = n_tiles ,
767767 desc = "Tile-based initial segmentation"
@@ -781,7 +781,7 @@ def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n
781781 "input_size" : tile_features .attrs ["input_size" ],
782782 "original_size" : this_tile_shape
783783 }
784- util .set_precomputed (self .predictor , tile_image_embeddings , i )
784+ util .set_precomputed (self ._predictor , tile_image_embeddings , i )
785785 tile_data = self ._compute_mask_data (initial_segmentations [tile_id ], this_tile_shape , verbose = False )
786786 mask_data .append (tile_data )
787787
@@ -821,7 +821,7 @@ def initialize(
821821 "Embeddings with tiling can only be computed if a save path is given."
822822 )
823823 image_embeddings = util .precompute_image_embeddings (
824- self .predictor , image , tile_shape = tile_shape , halo = halo , save_path = embedding_save_path
824+ self ._predictor , image , tile_shape = tile_shape , halo = halo , save_path = embedding_save_path
825825 )
826826 elif image_embeddings is None and not have_tiling_params :
827827 raise ValueError ("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)" )
@@ -903,12 +903,12 @@ def segment_tile(_, tile_id):
903903 mask_data = self ._postprocess_masks (
904904 mask_data , 0 , box_nms_thresh , box_nms_thresh , output_mode = "binary_mask"
905905 )
906- mask_data = mask_data_to_segmentation (mask_data , this_tile_shape , with_background = self .with_background )
906+ mask_data = mask_data_to_segmentation (mask_data , this_tile_shape , with_background = self ._with_background )
907907 return mask_data
908908
909909 input_ = _FakeInput (self .original_size )
910910 segmentation = stitch_segmentation (
911- input_ , segment_tile , self ._tile_shape , self ._halo , with_background = self .with_background , verbose = verbose
911+ input_ , segment_tile , self ._tile_shape , self ._halo , with_background = self ._with_background , verbose = verbose
912912 )
913913
914914 if min_mask_region_area > 0 :
@@ -940,7 +940,7 @@ def segment_tile(_, tile_id):
940940 initial_segmentation = stitch_segmentation (
941941 input_ , segment_tile ,
942942 self ._tile_shape , self ._halo ,
943- with_background = self .with_background , verbose = False
943+ with_background = self ._with_background , verbose = False
944944 )
945945
946946 self ._stitched_initial_segmentation = initial_segmentation
0 commit comments