
Commit f984d9c: Fix issues with iterative prompting

Parent: 7bd0536

3 files changed: 77 additions, 55 deletions

finetuning/livecell/evaluation/iterative_prompting.py (1 addition, 1 deletion)

```diff
@@ -10,7 +10,7 @@ def main():
 
     run_inference_with_iterative_prompting(
         checkpoint, model_type, image_paths, gt_paths,
-        prediction_root, use_boxes=False,
+        prediction_root, use_boxes=False, batch_size=16,
     )
```
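The batch size passed here bounds how many point prompts the mask decoder processes per forward pass. As a rough, hypothetical illustration of what that means for an image with many annotated objects:

```python
import math

n_samples, batch_size = 100, 16  # hypothetical: 100 objects in one image
n_batches = math.ceil(n_samples / batch_size)
print(n_batches)  # 7 decoder passes instead of one oversized one
```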
micro_sam/evaluation/inference.py (17 additions, 16 deletions)
```diff
@@ -425,10 +425,9 @@ def run_inference_with_prompts(
 
 def _save_segmentation(masks, prediction_path):
     # masks to segmentation
-    masks = masks.numpy().squeeze()
+    masks = masks.cpu().numpy().squeeze().astype("bool")
     shape = masks.shape[-2:]
-    masks = {"segmentation": mask for mask in masks}
-    breakpoint()
+    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
     segmentation = mask_data_to_segmentation(masks, shape, with_background=True)
     imageio.imwrite(prediction_path, segmentation)
```
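The old dict comprehension used a constant key, so it silently collapsed all instances into a single entry (and the leftover `breakpoint()` halted every run). A minimal sketch of the difference, using hypothetical mask data:

```python
import numpy as np

# three hypothetical instance masks
masks = np.zeros((3, 4, 4), dtype="bool")

# old: a dict comprehension with a constant key keeps only the last mask
broken = {"segmentation": mask for mask in masks}
assert len(broken) == 1  # all but one instance are dropped

# new: one record per instance, the list-of-dicts format consumed
# downstream by mask_data_to_segmentation
fixed = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
assert len(fixed) == 3 and all("area" in m for m in fixed)
```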
```diff
@@ -441,7 +440,7 @@ def _run_inference_with_iterative_prompting_for_image(
     device,
     use_boxes,
     prediction_paths,
-    batch_size=64
+    batch_size,
 ):
     assert len(prediction_paths) == n_iterations, f"{len(prediction_paths)}, {n_iterations}"
     to_sam_inputs = ConvertToSamInputs()
```
```diff
@@ -453,9 +452,6 @@
 
     n_pos = 0 if use_boxes else 1
     batched_inputs, sampled_ids = to_sam_inputs(image, gt, n_pos=n_pos, n_neg=0, get_boxes=use_boxes)
-    sampled_binary_y = torch.stack([
-        torch.stack([_gt == idx for idx in sampled]) for _gt, sampled in zip(gt, sampled_ids)
-    ]).to(torch.float32)
 
     input_images = torch.stack([model.preprocess(x=x["image"].to(device)) for x in batched_inputs], dim=0)
     image_embeddings = model.image_embeddings_oft(input_images)
```
```diff
@@ -471,15 +467,19 @@
         for batch_idx in range(n_batches):
             batch_start = batch_idx * batch_size
             batch_stop = min((batch_idx + 1) * batch_size, n_samples)
-            tmp_batched_inputs = deepcopy(batched_inputs)
-            for k, v in tmp_batched_inputs[0].items():
-                if k == "point_coords":
-                    tmp_batched_inputs[0]["point_coords"] = v[batch_start:batch_stop]
-                if k == "point_labels":
-                    tmp_batched_inputs[0]["point_labels"] = v[batch_start:batch_stop]
+
+            this_batched_inputs = [{
+                k: v[batch_start:batch_stop] if k in ("point_coords", "point_labels") else v
+                for k, v in batched_inputs[0].items()
+            }]
+
+            sampled_binary_y = torch.stack([
+                torch.stack([_gt == idx for idx in sampled[batch_start:batch_stop]])[:, None]
+                for _gt, sampled in zip(gt, sampled_ids)
+            ]).to(torch.float32)
 
             batched_outputs = model(
-                tmp_batched_inputs,
+                this_batched_inputs,
                 multimask_output=multimasking if iteration == 0 else False,
                 image_embeddings=image_embeddings
            )
```
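The `deepcopy` of the full input dict (image tensor included) is replaced by a comprehension that slices only the prompt tensors, and the ground-truth binarization now happens per batch. A sketch of the slicing with assumed keys and shapes (the exact layout of SAM's batched-input dict is an assumption here):

```python
import torch

# hypothetical batched input for one image with 100 point prompts
batched_inputs = [{
    "image": torch.zeros(3, 1024, 1024),     # assumed key: shared image tensor
    "point_coords": torch.zeros(100, 1, 2),  # one point prompt per object
    "point_labels": torch.ones(100, 1),
}]

batch_size, batch_idx = 16, 0
batch_start = batch_idx * batch_size
batch_stop = min((batch_idx + 1) * batch_size, 100)

# slice only the prompt tensors; everything else is carried over by
# reference, so the image is not copied per batch (unlike the old deepcopy)
this_batched_inputs = [{
    k: v[batch_start:batch_stop] if k in ("point_coords", "point_labels") else v
    for k, v in batched_inputs[0].items()
}]
assert this_batched_inputs[0]["point_coords"].shape[0] == 16
assert this_batched_inputs[0]["image"].shape == (3, 1024, 1024)  # untouched
```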
```diff
@@ -499,7 +499,7 @@
             masks = (masks > 0.5).to(torch.float32)
             final_masks.append(masks)
 
-            for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, tmp_batched_inputs, logits_masks):
+            for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, this_batched_inputs, logits_masks):
                 next_coords, next_labels = prompt_generator(_gt, _pred, _inp["point_coords"], _inp["point_labels"])
                 _inp["point_coords"], _inp["point_labels"], _inp["mask_inputs"] = next_coords, next_labels, logits
 
```
```diff
@@ -515,6 +515,7 @@ def run_inference_with_iterative_prompting(
     prediction_root: Union[str, os.PathLike],
     use_boxes: bool,
     n_iterations: int = 8,
+    batch_size: int = 32,
 ) -> None:
     """
 
```
```diff
@@ -545,5 +546,5 @@
 
         with torch.no_grad():
             _run_inference_with_iterative_prompting_for_image(
-                model, image, gt, n_iterations, device, use_boxes, prediction_paths,
+                model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
             )
```

micro_sam/prompt_generators.py (59 additions, 38 deletions)
```diff
@@ -199,46 +199,62 @@ def __init__(self, device=None):
         self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
 
     def get_positive_points(self, pos_region, overlap_region):
-        tmp_pos_loc = torch.where(pos_region)
-        # condition below where there is no room for improvement for the model
-        # hence we put a positive point in the "already correct" regions
-        if torch.stack(tmp_pos_loc).shape[-1] == 0:
-            tmp_pos_loc = torch.where(overlap_region)
-
-        pos_index = np.random.choice(len(tmp_pos_loc[1]))
-        pos_coordinates = int(tmp_pos_loc[1][pos_index]), int(tmp_pos_loc[2][pos_index])
-        pos_coordinates = pos_coordinates[::-1]
-        pos_labels = 1
+        positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
+        # we may have objects without a positive region (= missing true foreground)
+        # in this case we just sample a point where the model was already correct
+        positive_locations = [
+            torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
+            for pos_loc, ovlp_reg in zip(positive_locations, overlap_region)
+        ]
+        # we sample one location for each object in the batch
+        sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations]
+        # get the corresponding coordinates (note that we flip the axis order to match the order expected by SAM)
+        pos_coordinates = [
+            [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices)
+        ]
+
+        # make sure that we still have the correct batch size
+        assert len(pos_coordinates) == pos_region.shape[0]
+        pos_labels = [1] * len(pos_coordinates)
+
         return pos_coordinates, pos_labels
 
-    def get_negative_points(self, neg_region, true_object, gt):
-        tmp_neg_loc = torch.where(neg_region)
-        if torch.stack(tmp_neg_loc).shape[-1] == 0:
-            tmp_true_loc = torch.where(true_object)
-            x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
-            bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
-                                torch.max(x_coords) + 1, torch.max(y_coords) + 1])
-            bbox_mask = torch.zeros_like(true_object).squeeze(0)
-            bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
-            bbox_mask = bbox_mask[None].to(self.device)
-
-            # NOTE: FIX: here we add dilation to the bbox because in some cases we couldn't find objects at all
-            # TODO: just expand the pixels of the bbox
-            dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
-            background_mask = abs(dilated_bbox_mask - true_object)
-            tmp_neg_loc = torch.where(background_mask)
-
-            # there is a chance that the object is too small to return a decent-sized bounding box
-            # so we might not find points there either; in that case we sample from the true background
-            if torch.stack(tmp_neg_loc).shape[-1] == 0:
-                tmp_neg_loc = torch.where(gt == 0)
+    # TODO get rid of this looped implementation and use proper batched computation instead
+    def get_negative_points(self, negative_region_batched, true_object_batched, gt_batched):
+        negative_coordinates, negative_labels = [], []
 
-        neg_index = np.random.choice(len(tmp_neg_loc[1]))
-        neg_coordinates = int(tmp_neg_loc[1][neg_index]), int(tmp_neg_loc[2][neg_index])
-        neg_coordinates = neg_coordinates[::-1]
-        neg_labels = 0
+        for neg_region, true_object, gt in zip(negative_region_batched, true_object_batched, gt_batched):
 
-        return neg_coordinates, neg_labels
+            tmp_neg_loc = torch.where(neg_region)
+            if torch.stack(tmp_neg_loc).shape[-1] == 0:
+                tmp_true_loc = torch.where(true_object)
+                x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
+                bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
+                                    torch.max(x_coords) + 1, torch.max(y_coords) + 1])
+                bbox_mask = torch.zeros_like(true_object).squeeze(0)
+                bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
+                bbox_mask = bbox_mask[None].to(self.device)
+
+                # NOTE: FIX: here we add dilation to the bbox because in some cases we couldn't find objects at all
+                # TODO: just expand the pixels of the bbox
+                dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
+                background_mask = abs(dilated_bbox_mask - true_object)
+                tmp_neg_loc = torch.where(background_mask)
+
+                # there is a chance that the object is too small to return a decent-sized bounding box,
+                # so we might not find points there either; in that case we sample from the true background
+                if torch.stack(tmp_neg_loc).shape[-1] == 0:
+                    tmp_neg_loc = torch.where(gt == 0)
+
+            neg_index = np.random.choice(len(tmp_neg_loc[1]))
+            neg_coordinates = [tmp_neg_loc[1][neg_index], tmp_neg_loc[2][neg_index]]
+            neg_coordinates = neg_coordinates[::-1]
+            neg_labels = 0
+
+            negative_coordinates.append(neg_coordinates)
+            negative_labels.append(neg_labels)
+
+        return negative_coordinates, negative_labels
 
     def __call__(
         self,
```
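A small, self-contained sketch of the new batched positive-point sampling (shapes and values are hypothetical; the fallback branch mirrors the comment in the diff: objects with no remaining false-negative region get a point from the already-correct overlap instead):

```python
import numpy as np
import torch

pos_region = torch.zeros(2, 1, 8, 8)      # (batch, channel, H, W): false negatives
pos_region[0, 0, 3, 5] = 1                # object 0: one missed pixel at y=3, x=5
overlap_region = torch.zeros(2, 1, 8, 8)  # pixels that are already correct
overlap_region[1, 0, 2, 2] = 1            # object 1 has no positive region left

positive_locations = [torch.where(reg) for reg in pos_region]
positive_locations = [
    torch.where(ovlp) if len(loc[0]) == 0 else loc
    for loc, ovlp in zip(positive_locations, overlap_region)
]
sampled = [np.random.choice(len(loc[0])) for loc in positive_locations]
# flip to (x, y), since SAM expects point coordinates in that order
coords = [[loc[-1][i], loc[-2][i]] for loc, i in zip(positive_locations, sampled)]
# coords: (5, 3) sampled for object 0, fallback point (2, 2) for object 1
```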
```diff
@@ -249,6 +265,7 @@ def __call__(
     ):
         """Generate the prompts for each object iteratively in the segmentation.
         """
+        assert gt.shape == object_mask.shape
         true_object = gt.to(self.device)
         expected_diff = (object_mask - true_object)
         neg_region = (expected_diff == 1).to(torch.float)
```
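For reference, the error regions that drive the sampling come from the difference between prediction and ground truth. Only the `neg_region` line is visible in this hunk's context, so the `pos_region` and `overlap_region` definitions below are assumptions consistent with it:

```python
import torch

object_mask = torch.tensor([[0., 1., 1.],
                            [0., 1., 0.]])  # current prediction
true_object = torch.tensor([[0., 0., 1.],
                            [0., 1., 1.]])  # ground truth

expected_diff = (object_mask - true_object)
neg_region = (expected_diff == 1).to(torch.float)   # false positives -> negative prompts
pos_region = (expected_diff == -1).to(torch.float)  # false negatives -> positive prompts (assumed)
overlap_region = object_mask * true_object          # already-correct pixels (assumed)
```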
```diff
@@ -257,8 +274,12 @@
 
         pos_coordinates, pos_labels = self.get_positive_points(pos_region, overlap_region)
         neg_coordinates, neg_labels = self.get_negative_points(neg_region, true_object, gt)
+        assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels)
+
+        pos_coordinates, neg_coordinates = torch.tensor(pos_coordinates)[:, None], torch.tensor(neg_coordinates)[:, None]
+        pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]
 
-        net_coords = torch.cat([current_points, torch.tensor([[pos_coordinates, neg_coordinates]])], dim=1)
-        net_labels = torch.cat([current_labels, torch.tensor([[pos_labels, neg_labels]])], dim=1)
+        net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
+        net_labels = torch.cat([current_labels, pos_labels, neg_labels], dim=1)
 
         return net_coords, net_labels
```