@@ -84,8 +84,18 @@ def get_sam_model(device=None, model_type="vit_h", checkpoint_path=None, return_
8484 return predictor
8585
8686
def _to_image(input_):
    """Return `input_` as a three-channel, channel-last image.

    A 2D array is replicated along a new trailing axis so that all three
    channels are identical. A 3D array whose last axis has length 3 is
    returned unchanged (the very same object, no copy is made).

    Arguments:
        input_ - array-like object exposing `ndim` and `shape`.

    Raises:
        ValueError - if the input is neither 2D nor 3D with 3 trailing channels.
    """
    n_dims = input_.ndim

    # Already three channels in the last axis: hand it back untouched.
    if n_dims == 3 and input_.shape[-1] == 3:
        return input_

    # Single-channel data: add a channel axis and repeat it three times.
    if n_dims == 2:
        channel = input_[..., None]
        return np.concatenate([channel, channel, channel], axis=-1)

    raise ValueError(f"Invalid input image of shape { input_ .shape } . Expect either 2D grayscale or 3D RGB image.")
95+
96+
8797def _compute_2d (input_ , predictor ):
88- image = np . concatenate ([ input_ [..., None ]] * 3 , axis = - 1 )
98+ image = _to_image ( input_ )
8999 predictor .set_image (image )
90100 features = predictor .get_image_embedding ()
91101 original_size = predictor .original_size
@@ -103,7 +113,7 @@ def _precompute_2d(input_, predictor, save_path):
103113 features = f ["features" ][:]
104114 original_size , input_size = f .attrs ["original_size" ], f .attrs ["input_size" ]
105115 else :
106- image = np . concatenate ([ input_ [..., None ]] * 3 , axis = - 1 )
116+ image = _to_image ( input_ )
107117 predictor .set_image (image )
108118 features = predictor .get_image_embedding ()
109119 original_size , input_size = predictor .original_size , predictor .input_size
@@ -186,7 +196,7 @@ def _precompute_3d(input_, predictor, save_path, lazy_loading):
186196 return image_embeddings
187197
188198
189- def precompute_image_embeddings (predictor , input_ , save_path = None , lazy_loading = False ):
199+ def precompute_image_embeddings (predictor , input_ , save_path = None , lazy_loading = False , ndim = None ):
190200 """Compute the image embeddings (output of the encoder) for the input.
191201
192202 If save_path is given the embeddings will be loaded/saved in a zarr container.
@@ -198,13 +208,15 @@ def precompute_image_embeddings(predictor, input_, save_path=None, lazy_loading=
198208 lazy_loading [bool] - whether to load all embeddings into memory or return an
199209 object to load them on demand when required. This only has an effect if 'save_path'
200210 is given and if the input is 3D. (default: False)
211+ ndim [int] - the dimensionality of the data. If not given will be deduced from the input data. (default: None)
201212 """
202213
203- if input_ .ndim == 2 :
214+ ndim = input_ .ndim if ndim is None else ndim
215+ if ndim == 2 :
204216 image_embeddings = _compute_2d (input_ , predictor ) if save_path is None else \
205217 _precompute_2d (input_ , predictor , save_path )
206218
207- elif input_ . ndim == 3 :
219+ elif ndim == 3 :
208220 image_embeddings = _compute_3d (input_ , predictor ) if save_path is None else \
209221 _precompute_3d (input_ , predictor , save_path , lazy_loading )
210222
0 commit comments