Skip to content

Commit 343361f

Browse files
2 parents 1a134ec + c61896e commit 343361f

File tree

4 files changed

+234
-79
lines changed

4 files changed

+234
-79
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
from glob import glob
3+
4+
from micro_sam.evaluation.inference import run_inference_with_iterative_prompting
5+
from micro_sam.evaluation.evaluation import run_evaluation
6+
7+
from util import get_checkpoint, get_paths
8+
9+
LIVECELL_GT_ROOT = "/scratch-grete/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images"
10+
# TODO: update this so that it also fits other models
11+
PREDICTION_ROOT = "./pred_interactive_prompting"
12+
13+
14+
def run_interactive_prompting():
    """Run the iterative prompting inference and store the results under PREDICTION_ROOT.

    The checkpoint (requested for "vit_b") and the image / ground-truth paths are obtained
    from the local `util` module. Inference uses point prompts (`use_boxes=False`) and
    processes the objects of an image in batches of 16.
    """
    checkpoint, model_type = get_checkpoint("vit_b")
    image_paths, gt_paths = get_paths()

    run_inference_with_iterative_prompting(
        checkpoint_path=checkpoint,
        model_type=model_type,
        image_paths=image_paths,
        gt_paths=gt_paths,
        prediction_root=PREDICTION_ROOT,
        use_boxes=False,
        batch_size=16,
    )
24+
25+
26+
def get_pg_paths(pred_folder):
    """Collect the predictions in a folder together with the matching ground-truth files.

    The ground-truth for a prediction named `<prefix>_<rest>.tif` is expected at
    `LIVECELL_GT_ROOT/<prefix>/<prefix>_<rest>.tif`, i.e. the part of the file name
    before the first underscore selects the sub-folder.

    Args:
        pred_folder: Folder that contains the predicted segmentations as tif files.

    Returns:
        The sorted prediction paths and the corresponding ground-truth paths (same order).

    Raises:
        FileNotFoundError: If the ground-truth file for any prediction does not exist.
    """
    pred_paths = sorted(glob(os.path.join(pred_folder, "*.tif")))

    gt_paths = []
    for pred_path in pred_paths:
        name = os.path.basename(pred_path)
        gt_paths.append(os.path.join(LIVECELL_GT_ROOT, name.split("_")[0], name))

    # Raise explicitly instead of asserting: asserts are stripped with `python -O`
    # and would not tell which file is missing.
    missing = [gt_path for gt_path in gt_paths if not os.path.exists(gt_path)]
    if missing:
        raise FileNotFoundError(
            f"Could not find the ground-truth for {len(missing)} prediction(s) in {pred_folder}, "
            f"e.g. {missing[0]}"
        )

    return pred_paths, gt_paths
34+
35+
36+
def evaluate_interactive_prompting():
    """Evaluate the predictions of every prompting iteration and print the results.

    Each folder matching `iteration*` below PREDICTION_ROOT is evaluated on its own
    (in sorted order) against the matching ground-truth files; nothing is saved to disk.
    """
    folder_pattern = os.path.join(PREDICTION_ROOT, "iteration*")
    for iteration_folder in sorted(glob(folder_pattern)):
        print("Evaluating", iteration_folder)
        predictions, ground_truths = get_pg_paths(iteration_folder)
        result = run_evaluation(ground_truths, predictions, save_path=None)
        print(result)
44+
45+
46+
def main():
    """Script entry point: evaluate the iterative prompting predictions that already exist."""
    # The inference step is currently disabled; uncomment the next line to compute the predictions first.
    # run_interactive_prompting()
    evaluate_interactive_prompting()
49+
50+
51+
if __name__ == "__main__":
52+
main()

micro_sam/evaluation/inference.py

Lines changed: 116 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import pickle
3+
import warnings
34

45
from copy import deepcopy
56
from typing import Any, Dict, List, Optional, Union
@@ -15,8 +16,9 @@
1516
from segment_anything.utils.transforms import ResizeLongestSide
1617

1718
from .. import util as util
18-
from ..training import get_trainable_sam_model, ConvertToSamInputs
19+
from ..instance_segmentation import mask_data_to_segmentation
1920
from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
21+
from ..training import get_trainable_sam_model, ConvertToSamInputs
2022

2123

2224
def _load_prompts(
@@ -422,48 +424,127 @@ def run_inference_with_prompts(
422424
pickle.dump(cached_box_prompts, f)
423425

424426

425-
def run_inference_with_iterative_prompting(
426-
image, gt, model_type, checkpoint_path, n_iterations, n_positive, n_negative,
427-
use_boxes, device=None, _sigmoid=torch.nn.Sigmoid()
427+
def _save_segmentation(masks, prediction_path):
    """Merge the predicted object masks into one segmentation and write it to disk.

    Args:
        masks: Torch tensor with the binary masks of all objects. The last two axes are taken
            as the image axes; all leading axes are flattened into a single object axis.
            (The caller in this file appears to pass (1, n_objects, 1, H, W) - TODO confirm.)
        prediction_path: File path the segmentation is written to with `imageio.imwrite`.
    """
    masks = masks.cpu().numpy().astype("bool")
    shape = masks.shape[-2:]

    # Collapse the leading axes explicitly. An axis-less squeeze() would also remove the object
    # axis if there is only a single object (or an image axis of length 1), and the loop below
    # would then iterate over the rows of one mask instead of over the objects.
    masks = masks.reshape(-1, *shape)

    mask_data = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    segmentation = mask_data_to_segmentation(mask_data, shape, with_background=True)
    imageio.imwrite(prediction_path, segmentation)
434+
435+
436+
def _run_inference_with_iterative_prompting_for_image(
    model,
    image,
    gt,
    n_iterations,
    device,
    use_boxes,
    prediction_paths,
    batch_size,
):
    """Run iterative prompting for a single image and save one segmentation per iteration.

    Args:
        model: The model used for prediction. It must provide `preprocess`, `image_embeddings_oft`
            and be callable with the batched inputs.
        image: The input image as numpy array (2D, or with an additional leading axis).
        gt: The ground-truth labels as numpy array; the prompts are derived from its objects.
        n_iterations: The number of prompting iterations.
        device: The device the image is moved to and that is handed to the prompt generator.
        use_boxes: Whether the initial prompt is a box (True) or a single positive point (False).
        prediction_paths: The output path for each iteration; must have length `n_iterations`.
        batch_size: The number of objects that are passed to the model at once.
    """
    assert len(prediction_paths) == n_iterations, f"{len(prediction_paths)}, {n_iterations}"
    to_sam_inputs = ConvertToSamInputs()

    # Add a batch axis (and a channel axis for 2D images) before converting to torch.
    image = torch.from_numpy(
        image[None, None] if image.ndim == 2 else image[None]
    )
    gt = torch.from_numpy(gt[None].astype("int32"))

    # Start from a single positive point per object, or from no points if boxes are used.
    n_pos = 0 if use_boxes else 1
    batched_inputs, sampled_ids = to_sam_inputs(image, gt, n_pos=n_pos, n_neg=0, get_boxes=use_boxes)

    # The image embeddings are computed once and re-used for all batches and iterations.
    input_images = torch.stack([model.preprocess(x=x["image"].to(device)) for x in batched_inputs], dim=0)
    image_embeddings = model.image_embeddings_oft(input_images)

    # Multi-mask output is only requested for the initial single-point prompt (see the model call below).
    multimasking = n_pos == 1
    prompt_generator = IterativePromptGenerator(device)

    # The objects of the (single) image are processed in chunks of `batch_size`.
    n_samples = len(sampled_ids[0])
    n_batches = int(np.ceil(float(n_samples) / batch_size))

    for iteration in range(n_iterations):
        final_masks = []
        for batch_idx in range(n_batches):
            batch_start = batch_idx * batch_size
            batch_stop = min((batch_idx + 1) * batch_size, n_samples)

            # Only the point prompts are restricted to the current chunk of objects,
            # all other entries are passed on unchanged.
            # NOTE(review): a per-object box entry (use_boxes=True) would not be sliced here - confirm.
            this_batched_inputs = [{
                k: v[batch_start:batch_stop] if k in ("point_coords", "point_labels") else v
                for k, v in batched_inputs[0].items()
            }]

            # Binary ground-truth mask for each object in the current chunk.
            sampled_binary_y = torch.stack([
                torch.stack([_gt == idx for idx in sampled[batch_start:batch_stop]])[:, None]
                for _gt, sampled in zip(gt, sampled_ids)
            ]).to(torch.float32)

            batched_outputs = model(
                this_batched_inputs,
                multimask_output=multimasking if iteration == 0 else False,
                image_embeddings=image_embeddings
            )

            # For each object keep the mask with the highest predicted IoU and the matching low-res logits.
            masks, logits_masks = [], []
            for m in batched_outputs:
                mask, l_mask = [], []
                for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]):
                    best_iou_idx = torch.argmax(_iou)
                    mask.append(torch.sigmoid(_m[best_iou_idx][None]))
                    l_mask.append(_l[best_iou_idx][None])
                mask, l_mask = torch.stack(mask), torch.stack(l_mask)
                masks.append(mask)
                logits_masks.append(l_mask)

            masks, logits_masks = torch.stack(masks), torch.stack(logits_masks)
            masks = (masks > 0.5).to(torch.float32)
            final_masks.append(masks)

            # Derive the prompts for the next iteration from the errors of the current prediction.
            # NOTE(review): the new prompts are only stored in `this_batched_inputs`, which is rebuilt from
            # `batched_inputs[0]` for every batch. `batched_inputs[0]` is not updated in this function, so
            # the next iteration appears to start from the initial prompts again - confirm. This may be the
            # reason for the warning raised in `run_inference_with_iterative_prompting`.
            # NOTE(review): "point_coords" is read unconditionally; presumably it is missing when
            # only boxes are used (use_boxes=True) - verify against ConvertToSamInputs.
            for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, this_batched_inputs, logits_masks):
                next_coords, next_labels = prompt_generator(_gt, _pred, _inp["point_coords"], _inp["point_labels"])
                _inp["point_coords"], _inp["point_labels"], _inp["mask_inputs"] = next_coords, next_labels, logits

        # Combine the chunks along the object axis and save the result of this iteration.
        final_masks = torch.cat(final_masks, dim=1)
        _save_segmentation(final_masks, prediction_paths[iteration])
509+
453510

454-
masks, logits_masks = [], []
455-
for m in batched_outputs:
456-
mask, l_mask = [], []
457-
for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]):
458-
best_iou_idx = torch.argmax(_iou)
459-
mask.append(_sigmoid(_m[best_iou_idx][None]))
460-
l_mask.append(_l[best_iou_idx][None])
461-
mask, l_mask = torch.stack(mask), torch.stack(l_mask)
462-
masks.append(mask)
463-
logits_masks.append(l_mask)
464-
masks, logits_masks = torch.stack(masks), torch.stack(logits_masks)
465-
masks = (masks > 0.5).to(torch.float32)
466-
467-
for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks):
468-
net_coords, net_labels = prompt_generator(_gt, _pred, _inp["point_coords"], _inp["point_labels"])
469-
_inp["point_coords"], _inp["point_labels"], _inp["mask_inputs"] = net_coords, net_labels, logits
511+
def run_inference_with_iterative_prompting(
    checkpoint_path: Union[str, os.PathLike],
    model_type: str,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    prediction_root: Union[str, os.PathLike],
    use_boxes: bool,
    n_iterations: int = 8,
    batch_size: int = 32,
) -> None:
    """@private

    Run inference with iterative prompting for a list of images.

    The prediction for iteration `i` of an image is saved to
    `<prediction_root>/iteration<i:02>/<image file name>`. Images for which the predictions
    of all iterations already exist are skipped.

    Args:
        checkpoint_path: The checkpoint of the model.
        model_type: The type of the model.
        image_paths: The paths to the input images.
        gt_paths: The paths to the ground-truth labels, in the same order as `image_paths`.
        prediction_root: The root folder for saving the predictions.
        use_boxes: Whether to start from box prompts instead of a single positive point.
        n_iterations: The number of prompting iterations.
        batch_size: The number of objects that are passed to the model at once.

    Raises:
        ValueError: If the number of images and ground-truth files differs.
        FileNotFoundError: If an image or ground-truth file that needs to be processed does not exist.
    """
    warnings.warn("The iterative prompting functionality is not working correctly yet.")

    # zip would silently drop the surplus entries if the two lists had different lengths.
    if len(image_paths) != len(gt_paths):
        raise ValueError(
            f"Expected the same number of images and ground-truth files, got {len(image_paths)} and {len(gt_paths)}"
        )

    # Fall back to the CPU instead of failing on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_trainable_sam_model(model_type, checkpoint_path)

    # create all prediction folders
    iteration_folders = [os.path.join(prediction_root, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_folder in iteration_folders:
        os.makedirs(iteration_folder, exist_ok=True)

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)

        # Skip images that have been fully processed in a previous run.
        prediction_paths = [os.path.join(iteration_folder, image_name) for iteration_folder in iteration_folders]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        for required_path in (image_path, gt_path):
            if not os.path.exists(required_path):
                raise FileNotFoundError(required_path)

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        with torch.no_grad():
            _run_inference_with_iterative_prompting_for_image(
                model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
            )

micro_sam/prompt_generators.py

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -199,56 +199,73 @@ def __init__(self, device=None):
199199
self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
200200

201201
def get_positive_points(self, pos_region, overlap_region):
    """Sample one positive point prompt for each object in the batch.

    Args:
        pos_region: Batch of masks (first axis: objects) marking where foreground is still missing
            in the prediction.
        overlap_region: Batch of masks marking where the prediction is already correct.
            Used as fallback for objects with an empty positive region.

    Returns:
        The sampled coordinates as list of `[x, y]` integer pairs (one per object; the axis order
        is flipped compared to the array indices, as expected by SAM) and the list of labels (all 1).

    Raises:
        ValueError: If an object has neither a positive nor an overlap region to sample from.
    """
    pos_coordinates = []
    for pos_reg, ovlp_reg in zip(pos_region, overlap_region):
        candidates = torch.where(pos_reg)
        # we may have objects without a positive region (= missing true foreground)
        # in this case we just sample a point where the model was already correct
        if len(candidates[0]) == 0:
            candidates = torch.where(ovlp_reg)

        n_candidates = len(candidates[0])
        if n_candidates == 0:
            raise ValueError("Cannot sample a positive point: both the positive and the overlap region are empty.")

        # we sample one location for each object in the batch
        idx = np.random.choice(n_candidates)
        # Flip the axis order (last axis first) and convert to plain ints,
        # so that the result does not contain zero-dimensional (device) tensors.
        pos_coordinates.append([int(candidates[-1][idx]), int(candidates[-2][idx])])

    # make sure that we still have the correct batch size
    assert len(pos_coordinates) == pos_region.shape[0]
    pos_labels = [1] * len(pos_coordinates)

    return pos_coordinates, pos_labels
213221

214-
def get_negative_points(self, neg_region, true_object, gt):
215-
tmp_neg_loc = torch.where(neg_region)
216-
if torch.stack(tmp_neg_loc).shape[-1] == 0:
217-
tmp_true_loc = torch.where(true_object)
218-
x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
219-
bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
220-
torch.max(x_coords) + 1, torch.max(y_coords) + 1])
221-
bbox_mask = torch.zeros_like(true_object).squeeze(0)
222-
bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
223-
bbox_mask = bbox_mask[None].to(self.device)
224-
225-
# NOTE: FIX: here we add dilation to the bbox because in some case we couldn't find objects at all
226-
# TODO: just expand the pixels of bbox
227-
dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
228-
background_mask = abs(dilated_bbox_mask - true_object)
229-
tmp_neg_loc = torch.where(background_mask)
230-
231-
# there is a chance that the object is small to not return a decent-sized bounding box
232-
# hence we might not find points sometimes there as well, hence we sample points from true background
233-
if torch.stack(tmp_neg_loc).shape[-1] == 0:
234-
tmp_neg_loc = torch.where(gt == 0)
222+
# TODO get rid of this looped implementation and use proper batched computation instead
223+
def get_negative_points(self, negative_region_batched, true_object_batched, gt_batched):
    """Sample one negative point prompt for each object in the batch.

    The point is sampled from the first non-empty candidate set of:
    1. the negative region (pixels that are wrongly predicted as foreground),
    2. the dilated bounding box of the true object without the object itself,
    3. the pixels where `gt` is zero.

    Args:
        negative_region_batched: Batch of negative regions (first axis: objects).
        true_object_batched: Batch of masks of the true objects.
        gt_batched: Batch of ground-truth masks, used for the last fallback.

    Returns:
        The sampled coordinates as list of `[x, y]` integer pairs (one per object; the axis order
        is flipped compared to the array indices) and the list of labels (all 0).

    Raises:
        ValueError: If none of the candidate sets contains a pixel for an object.
    """
    negative_coordinates, negative_labels = [], []

    for neg_region, true_object, gt in zip(negative_region_batched, true_object_batched, gt_batched):

        tmp_neg_loc = torch.where(neg_region)
        if len(tmp_neg_loc[0]) == 0:
            # Nothing is wrongly predicted as foreground: sample around the object instead.
            tmp_true_loc = torch.where(true_object)
            rows, cols = tmp_true_loc[1], tmp_true_loc[2]
            bbox = torch.stack([torch.min(rows), torch.min(cols),
                                torch.max(rows) + 1, torch.max(cols) + 1])
            bbox_mask = torch.zeros_like(true_object).squeeze(0)
            bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
            bbox_mask = bbox_mask[None].to(self.device)

            # NOTE: the bounding box is dilated because in some cases no candidates were found otherwise
            # TODO: just expand the pixels of bbox
            dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
            background_mask = abs(dilated_bbox_mask - true_object)
            tmp_neg_loc = torch.where(background_mask)

        # The bounding box fallback can be empty as well (the original note attributes this to
        # small objects), so we sample from the true background in that case.
        if len(tmp_neg_loc[0]) == 0:
            tmp_neg_loc = torch.where(gt == 0)

        n_candidates = len(tmp_neg_loc[1])
        if n_candidates == 0:
            raise ValueError("Cannot sample a negative point: no background candidates were found.")

        neg_index = np.random.choice(n_candidates)
        # Flip the axis order and convert to plain ints,
        # so that the result does not contain zero-dimensional (device) tensors.
        negative_coordinates.append([int(tmp_neg_loc[2][neg_index]), int(tmp_neg_loc[1][neg_index])])
        negative_labels.append(0)

    return negative_coordinates, negative_labels
242258

243259
def __call__(
244-
self,
245-
gt,
246-
object_mask,
247-
current_points,
248-
current_labels
260+
self,
261+
gt,
262+
object_mask,
263+
current_points,
264+
current_labels
249265
):
250266
"""Generate the prompts for each object iteratively in the segmentation.
251267
"""
268+
assert gt.shape == object_mask.shape
252269
true_object = gt.to(self.device)
253270
expected_diff = (object_mask - true_object)
254271
neg_region = (expected_diff == 1).to(torch.float)
@@ -257,8 +274,12 @@ def __call__(
257274

258275
pos_coordinates, pos_labels = self.get_positive_points(pos_region, overlap_region)
259276
neg_coordinates, neg_labels = self.get_negative_points(neg_region, true_object, gt)
277+
assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels)
278+
279+
pos_coordinates, neg_coordinates = torch.tensor(pos_coordinates)[:, None], torch.tensor(neg_coordinates)[:, None]
280+
pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]
260281

261-
net_coords = torch.cat([current_points, torch.tensor([[pos_coordinates, neg_coordinates]])], dim=1)
262-
net_labels = torch.cat([current_labels, torch.tensor([[pos_labels, neg_labels]])], dim=1)
282+
net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
283+
net_labels = torch.cat([current_labels, pos_labels, neg_labels], dim=1)
263284

264285
return net_coords, net_labels

micro_sam/training/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
153153
gt = gt.squeeze().numpy().astype(np.int32)
154154
point_coordinates, bbox_coordinates = get_centers_and_bounding_boxes(gt)
155155

156+
this_n_samples = len(point_coordinates) if n_samples is None else n_samples
156157
box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
157-
gt, n_samples,
158+
gt, this_n_samples,
158159
n_pos, n_neg,
159160
get_boxes,
160161
get_points,

0 commit comments

Comments
 (0)