@@ -74,6 +74,7 @@ def automatic_instance_segmentation(
7474 tile_shape : Optional [Tuple [int , int ]] = None ,
7575 halo : Optional [Tuple [int , int ]] = None ,
7676 verbose : bool = True ,
77+ return_embeddings : bool = False ,
7778 ** generate_kwargs
7879) -> np .ndarray :
7980 """Run automatic segmentation for the input image.
@@ -92,6 +93,7 @@ def automatic_instance_segmentation(
9293 tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
9394 halo: Overlap of the tiles for tiled prediction.
9495 verbose: Verbosity flag.
96+ return_embeddings: Whether to return the precomputed image embeddings.
9597 generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
9698
9799 Returns:
@@ -142,23 +144,32 @@ def automatic_instance_segmentation(
142144 if (image_data .ndim != 3 ) and (image_data .ndim != 4 and image_data .shape [- 1 ] != 3 ):
143145 raise ValueError (f"The inputs does not match the shape expectation of 3d inputs: { image_data .shape } " )
144146
145- instances = automatic_3d_segmentation (
147+ outputs = automatic_3d_segmentation (
146148 volume = image_data ,
147149 predictor = predictor ,
148150 segmentor = segmenter ,
149151 embedding_path = embedding_path ,
150152 tile_shape = tile_shape ,
151153 halo = halo ,
152154 verbose = verbose ,
155+ return_embeddings = return_embeddings ,
153156 ** generate_kwargs
154157 )
155158
159+ if return_embeddings :
160+ instances , image_embeddings = outputs
161+ else :
162+ instances = outputs
163+
164+ # Save the instance segmentation, if 'output_path' provided.
156165 if output_path is not None :
157- # Save the instance segmentation
158166 output_path = Path (output_path ).with_suffix (".tif" )
159167 imageio .imwrite (output_path , instances , compression = "zlib" )
160168
161- return instances
169+ if return_embeddings :
170+ return instances , image_embeddings
171+ else :
172+ return instances
162173
163174
164175def main ():
@@ -194,8 +205,7 @@ def main():
194205 help = f"The segment anything model that will be used, one of { available_models } ."
195206 )
196207 parser .add_argument (
197- "-c" , "--checkpoint" , default = None ,
198- help = "Checkpoint from which the SAM model will be loaded."
208+ "-c" , "--checkpoint" , default = None , help = "Checkpoint from which the SAM model will be loaded."
199209 )
200210 parser .add_argument (
201211 "--tile_shape" , nargs = "+" , type = int , help = "The tile shape for using tiled prediction." , default = None
0 commit comments