@@ -320,7 +320,13 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
320320 return data
321321
322322 @torch .no_grad ()
323- def initialize (self , image : np .ndarray , image_embeddings = None , i = None , verbose = False ):
323+ def initialize (
324+ self ,
325+ image : np .ndarray ,
326+ image_embeddings = None ,
327+ i : Optional [int ] = None ,
328+ verbose : bool = False
329+ ):
324330 """
325331 """
326332 original_size = image .shape [:2 ]
@@ -403,14 +409,14 @@ class EmbeddingMaskGenerator(_AMGBase):
403409 def __init__ (
404410 self ,
405411 predictor : SamPredictor ,
406- offsets = None ,
407- min_initial_size = 0 ,
408- distance_type = "l2" ,
409- bias = 0.0 ,
410- use_box = True ,
411- use_mask = True ,
412- use_points = False ,
413- box_extension = 0.05 ,
412+ offsets : Optional [ List [ List [ int ]]] = None ,
413+ min_initial_size : int = 0 ,
414+ distance_type : str = "l2" ,
415+ bias : float = 0.0 ,
416+ use_box : bool = True ,
417+ use_mask : bool = True ,
418+ use_points : bool = False ,
419+ box_extension : float = 0.05 ,
414420 ):
415421 super ().__init__ ()
416422
@@ -475,7 +481,13 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
475481 return mask_data
476482
477483 @torch .no_grad ()
478- def initialize (self , image : np .ndarray , image_embeddings = None , i = None , verbose = False ):
484+ def initialize (
485+ self ,
486+ image : np .ndarray ,
487+ image_embeddings = None ,
488+ i : Optional [int ] = None ,
489+ verbose : bool = False
490+ ):
479491 """
480492 """
481493 original_size = image .shape [:2 ]
@@ -545,10 +557,17 @@ def get_initial_segmentation(self):
545557class TiledEmbeddingMaskGenerator (EmbeddingMaskGenerator ):
546558 """
547559 """
548- def __init__ (self , n_threads = mp .cpu_count (), with_background = True , ** kwargs ):
549- super ().__init__ (** kwargs )
560+ def __init__ (
561+ self ,
562+ predictor : SamPredictor ,
563+ n_threads : int = mp .cpu_count (),
564+ with_background : bool = True ,
565+ ** kwargs
566+ ):
567+ super ().__init__ (predictor = predictor , ** kwargs )
550568 self .n_threads = n_threads
551569 self .with_background = with_background
570+ # additional data for 'initialize'
552571 self ._tile_shape = None
553572 self ._halo = None
554573 self ._stitched_initial_segmentation = None
@@ -599,14 +618,17 @@ def initialize(
599618 tile_shape : List [int ],
600619 halo : List [int ],
601620 image_embeddings = None ,
602- i = None ,
603- verbose = False ,
621+ i : Optional [int ] = None ,
622+ verbose : bool = False ,
623+ embedding_save_path : Optional [str ] = None ,
604624 ):
605625 """
606626 """
607627 original_size = image .shape [:2 ]
608628 if image_embeddings is None :
609- image_embeddings = util .precompute_image_embeddings (self .predictor , image , tile_shape = tile_shape , halo = halo )
629+ image_embeddings = util .precompute_image_embeddings (
630+ self .predictor , image , tile_shape = tile_shape , halo = halo , save_path = embedding_save_path
631+ )
610632
611633 tiling = blocking ([0 , 0 ], original_size , tile_shape )
612634 n_tiles = tiling .numberOfBlocks
0 commit comments