@@ -288,12 +288,14 @@ def _process_batch(self, points, im_size):
288288
289289 return data
290290
291- def _process_crop (self , image , crop_box , crop_layer_idx , verbose ):
291+ def _process_crop (self , image , crop_box , crop_layer_idx , verbose , precomputed_embeddings ):
292292 # crop the image and calculate embeddings
293293 x0 , y0 , x1 , y1 = crop_box
294294 cropped_im = image [y0 :y1 , x0 :x1 , :]
295295 cropped_im_size = cropped_im .shape [:2 ]
296- self .predictor .set_image (cropped_im )
296+
297+ if not precomputed_embeddings :
298+ self .predictor .set_image (cropped_im )
297299
298300 # get the points for this crop
299301 points_scale = np .array (cropped_im_size )[None , ::- 1 ]
@@ -312,23 +314,39 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose):
312314 data .cat (batch_data )
313315 del batch_data
314316
315- self .predictor .reset_image ()
317+ if not precomputed_embeddings :
318+ self .predictor .reset_image ()
319+
316320 return data
317321
318- # TODO enable initializeing with embeddings
319- # (which can be done for only a single crop box)
320322 @torch .no_grad ()
321- def initialize (self , image : np .ndarray , verbose = False ):
323+ def initialize (self , image : np .ndarray , image_embeddings = None , i = None , embedding_path = None , verbose = False ):
322324 """
323325 """
324- image = util ._to_image (image )
325326 original_size = image .shape [:2 ]
326327 crop_boxes , layer_idxs = amg_utils .generate_crop_boxes (
327328 original_size , self .crop_n_layers , self .crop_overlap_ratio
328329 )
330+
331+ # we can set fixed image embeddings if we only have a single crop box
332+ # (which is the default setting)
333+ # otherwise we have to recompute the embeddings for each crop and can't precompute
334+ if len (crop_boxes ) == 1 :
335+ if image_embeddings is None :
336+ image_embeddings = util .precompute_image_embeddings (self .predictor , image , save_path = embedding_path )
337+ util .set_precomputed (self .predictor , image_embeddings , i = i )
338+ precomputed_embeddings = True
339+ else :
340+ precomputed_embeddings = False
341+
342+ # we need to cast to the image representation that is compatible with SAM
343+ image = util ._to_image (image )
344+
329345 crop_list = []
330346 for crop_box , layer_idx in zip (crop_boxes , layer_idxs ):
331- crop_data = self ._process_crop (image , crop_box , layer_idx , verbose = verbose )
347+ crop_data = self ._process_crop (
348+ image , crop_box , layer_idx , verbose = verbose , precomputed_embeddings = precomputed_embeddings
349+ )
332350 crop_list .append (crop_data )
333351
334352 self ._is_initialized = True
0 commit comments