@@ -93,8 +93,9 @@ def cache_is_state(
9393 save_path : Union [str , os .PathLike ],
9494 verbose : bool = True ,
9595 i : Optional [int ] = None ,
96+ skip_load : bool = False ,
9697 ** kwargs ,
97- ) -> instance_segmentation .AMGBase :
98+ ) -> Optional [ instance_segmentation .AMGBase ] :
9899 """Compute and cache or load the state for the automatic mask generator.
99100
100101 Args:
@@ -105,6 +106,7 @@ def cache_is_state(
105106 save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
106107 verbose: Whether to run the computation verbose.
107108 i: The index for which to cache the state.
109+ skip_load: Skip loading the state if it is precomputed.
108110 kwargs: The keyword arguments for the amg class.
109111
110112 Returns:
@@ -120,6 +122,9 @@ def cache_is_state(
120122
121123 with h5py .File (save_path , "a" ) as f :
122124 if save_key in f :
125+ if skip_load : # Skip loading to speed this up for cases where we don't need the return val.
126+ return
127+
123128 if verbose :
124129 print ("Load instance segmentation state from" , save_path , ":" , save_key )
125130 g = f [save_key ]
@@ -169,6 +174,7 @@ def _precompute_state_for_files(
169174 predictor : SamPredictor ,
170175 input_files : Union [List [Union [os .PathLike , str ]], List [np .ndarray ]],
171176 output_path : Union [os .PathLike , str ],
177+ key : Optional [str ] = None ,
172178 ndim : Optional [int ] = None ,
173179 tile_shape : Optional [Tuple [int , int ]] = None ,
174180 halo : Optional [Tuple [int , int ]] = None ,
@@ -185,14 +191,15 @@ def _precompute_state_for_files(
185191
186192 _precompute_state_for_file (
187193 predictor , file_path , out_path ,
188- key = None , ndim = ndim , tile_shape = tile_shape , halo = halo ,
194+ key = key , ndim = ndim , tile_shape = tile_shape , halo = halo ,
189195 precompute_amg_state = precompute_amg_state , decoder = decoder ,
190196 )
191197
192198
193199def precompute_state (
194200 input_path : Union [os .PathLike , str ],
195201 output_path : Union [os .PathLike , str ],
202+ pattern : Optional [str ] = None ,
196203 model_type : str = util ._DEFAULT_MODEL ,
197204 checkpoint_path : Optional [Union [os .PathLike , str ]] = None ,
198205 key : Optional [str ] = None ,
@@ -209,31 +216,41 @@ def precompute_state(
209216 In case of a container file the argument `key` must be given. In case of a folder
210217 it can be given to provide a glob pattern to subselect files from the folder.
211218 output_path: The output path were the embeddings and other state will be saved.
219+ pattern: Glob pattern to select files in a folder. The embeddings will be computed
220+ for each of these files. To select all files in a folder pass "*".
212221 model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
213222 checkpoint_path: Path to a checkpoint for a custom model.
214223 key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr)
215- and can be used to provide a glob pattern if the input is a folder with image files .
224+ or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case .
216225 ndim: The dimensionality of the data.
217226 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
218227 halo: Overlap of the tiles for tiled prediction.
219228 precompute_amg_state: Whether to precompute the state for automatic instance segmentation
220229 in addition to the image embeddings.
221230 """
222- predictor = util .get_sam_model (model_type = model_type , checkpoint_path = checkpoint_path )
223- # check if we precompute the state for a single file or for a folder with image files
224- if os .path .isdir (input_path ) and Path (input_path ).suffix not in (".n5" , ".zarr" ):
225- pattern = "*" if key is None else key
226- input_files = glob (os .path .join (input_path , pattern ))
227- _precompute_state_for_files (
228- predictor , input_files , output_path ,
231+ predictor , state = util .get_sam_model (
232+ model_type = model_type , checkpoint_path = checkpoint_path , return_state = True ,
233+ )
234+ if "decoder_state" in state :
235+ decoder = instance_segmentation .get_decoder (predictor .model .image_encoder , state ["decoder_state" ])
236+ else :
237+ decoder = None
238+
239+ # Check if we precompute the state for a single file or for a folder with image files.
240+ if pattern is None :
241+ _precompute_state_for_file (
242+ predictor , input_path , output_path , key ,
229243 ndim = ndim , tile_shape = tile_shape , halo = halo ,
230244 precompute_amg_state = precompute_amg_state ,
245+ decoder = decoder ,
231246 )
232247 else :
233- _precompute_state_for_file (
234- predictor , input_path , output_path , key ,
248+ input_files = glob (os .path .join (input_path , pattern ))
249+ _precompute_state_for_files (
250+ predictor , input_files , output_path , key = key ,
235251 ndim = ndim , tile_shape = tile_shape , halo = halo ,
236252 precompute_amg_state = precompute_amg_state ,
253+ decoder = decoder ,
237254 )
238255
239256
@@ -253,11 +270,16 @@ def main():
253270 parser .add_argument (
254271 "-e" , "--embedding_path" , required = True , help = "The path where the embeddings will be saved."
255272 )
273+
274+ parser .add_argument (
275+ "--pattern" , help = "Pattern / wildcard for selecting files in a folder. To select all files use '*'."
276+ )
256277 parser .add_argument (
257278 "-k" , "--key" ,
258279 help = "The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
259- "for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
280+ "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
260281 )
282+
261283 parser .add_argument (
262284 "-m" , "--model_type" , default = util ._DEFAULT_MODEL ,
263285 help = f"The segment anything model that will be used, one of { available_models } ."
@@ -284,8 +306,10 @@ def main():
284306
285307 args = parser .parse_args ()
286308 precompute_state (
287- args .input_path , args .embedding_path , args .model_type , args .checkpoint ,
288- key = args .key , tile_shape = args .tile_shape , halo = args .halo , ndim = args .ndim ,
309+ args .input_path , args .embedding_path ,
310+ model_type = args .model_type , checkpoint_path = args .checkpoint ,
311+ pattern = args .pattern , key = args .key ,
312+ tile_shape = args .tile_shape , halo = args .halo , ndim = args .ndim ,
289313 precompute_amg_state = args .precompute_amg_state ,
290314 )
291315
0 commit comments