|
1 | 1 | import os |
2 | 2 | import pickle |
| 3 | +import warnings |
3 | 4 |
|
4 | 5 | from copy import deepcopy |
5 | 6 | from typing import Any, Dict, List, Optional, Union |
|
15 | 16 | from segment_anything.utils.transforms import ResizeLongestSide |
16 | 17 |
|
17 | 18 | from .. import util as util |
18 | | -from ..training import get_trainable_sam_model, ConvertToSamInputs |
| 19 | +from ..instance_segmentation import mask_data_to_segmentation |
19 | 20 | from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator |
| 21 | +from ..training import get_trainable_sam_model, ConvertToSamInputs |
20 | 22 |
|
21 | 23 |
|
22 | 24 | def _load_prompts( |
@@ -422,48 +424,127 @@ def run_inference_with_prompts( |
422 | 424 | pickle.dump(cached_box_prompts, f) |
423 | 425 |
|
424 | 426 |
|
425 | | -def run_inference_with_iterative_prompting( |
426 | | - image, gt, model_type, checkpoint_path, n_iterations, n_positive, n_negative, |
427 | | - use_boxes, device=None, _sigmoid=torch.nn.Sigmoid() |
def _save_segmentation(masks, prediction_path):
    """Merge a stack of binary object masks into one segmentation and write it to disk.

    Args:
        masks: Torch tensor holding one binary mask per object. The last two
            dimensions are treated as the spatial ones; all leading dimensions
            are flattened into a single object axis.
        prediction_path: Destination handed to `imageio.imwrite`.
    """
    masks = masks.cpu().numpy().astype("bool")
    shape = masks.shape[-2:]
    # Flatten every leading axis into one object axis instead of calling
    # `squeeze()`: with exactly one object `squeeze()` would also remove the
    # object axis, and the loop below would then iterate over image rows
    # rather than over masks.
    masks = masks.reshape((-1,) + tuple(shape))
    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    segmentation = mask_data_to_segmentation(masks, shape, with_background=True)
    imageio.imwrite(prediction_path, segmentation)
| 434 | + |
| 435 | + |
def _run_inference_with_iterative_prompting_for_image(
    model,
    image,
    gt,
    n_iterations,
    device,
    use_boxes,
    prediction_paths,
    batch_size,
):
    """Run the iterative prompting loop for a single image and save one result per iteration.

    Args:
        model: Callable model that also provides `preprocess` and
            `image_embeddings_oft`.
        image: Numpy array with the image data; a 2D array gets two leading
            singleton axes, any other array gets one.
        gt: Numpy array with the ground-truth labels.
        n_iterations: Number of prompting iterations to run.
        device: Device the input images are moved to and the prompt generator
            is created with.
        use_boxes: If True the initial prompts are requested with
            `get_boxes=True` and no positive point; otherwise with a single
            positive point.
        prediction_paths: One output path per iteration.
        batch_size: Number of sampled objects processed per model call.
    """
    assert len(prediction_paths) == n_iterations, f"{len(prediction_paths)}, {n_iterations}"
    converter = ConvertToSamInputs()

    if image.ndim == 2:
        image = image[None, None]
    else:
        image = image[None]
    image = torch.from_numpy(image)
    gt = torch.from_numpy(gt[None].astype("int32"))

    n_pos = 0 if use_boxes else 1
    batched_inputs, sampled_ids = converter(image, gt, n_pos=n_pos, n_neg=0, get_boxes=use_boxes)

    # The image embeddings are computed once and reused for every model call below.
    preprocessed = [model.preprocess(x=entry["image"].to(device)) for entry in batched_inputs]
    input_images = torch.stack(preprocessed, dim=0)
    image_embeddings = model.image_embeddings_oft(input_images)

    multimasking = n_pos == 1
    prompt_generator = IterativePromptGenerator(device)

    n_samples = len(sampled_ids[0])
    n_batches = int(np.ceil(float(n_samples) / batch_size))
    # NOTE(review): only these two keys are cut down to the current batch;
    # every other entry (e.g. box prompts, if present) is passed on unsliced.
    sliced_keys = ("point_coords", "point_labels")

    for iteration in range(n_iterations):
        # Multiple mask candidates are only requested in the first iteration.
        use_multimask = multimasking if iteration == 0 else False
        masks_per_batch = []

        for batch_idx in range(n_batches):
            start = batch_idx * batch_size
            stop = min(start + batch_size, n_samples)

            # NOTE(review): this dict is rebuilt from `batched_inputs[0]` on
            # every pass, so the prompt updates written into it at the end of
            # this loop body are not carried over to the next iteration.
            this_batched_inputs = [{
                key: (value[start:stop] if key in sliced_keys else value)
                for key, value in batched_inputs[0].items()
            }]

            # Binary ground-truth mask for each sampled object id of this batch.
            sampled_binary_y = torch.stack([
                torch.stack([_gt == idx for idx in sampled[start:stop]])[:, None]
                for _gt, sampled in zip(gt, sampled_ids)
            ]).to(torch.float32)

            batched_outputs = model(
                this_batched_inputs,
                multimask_output=use_multimask,
                image_embeddings=image_embeddings,
            )

            masks, logits_masks = [], []
            for output in batched_outputs:
                object_masks, object_logits = [], []
                candidates = zip(output["masks"], output["low_res_masks"], output["iou_predictions"])
                for candidate_masks, candidate_logits, candidate_ious in candidates:
                    # Keep the candidate with the largest `iou_predictions` entry.
                    best = torch.argmax(candidate_ious)
                    object_masks.append(torch.sigmoid(candidate_masks[best][None]))
                    object_logits.append(candidate_logits[best][None])
                masks.append(torch.stack(object_masks))
                logits_masks.append(torch.stack(object_logits))

            masks = (torch.stack(masks) > 0.5).to(torch.float32)
            logits_masks = torch.stack(logits_masks)
            masks_per_batch.append(masks)

            for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, this_batched_inputs, logits_masks):
                next_coords, next_labels = prompt_generator(_gt, _pred, _inp["point_coords"], _inp["point_labels"])
                _inp["point_coords"] = next_coords
                _inp["point_labels"] = next_labels
                _inp["mask_inputs"] = logits

        final_masks = torch.cat(masks_per_batch, dim=1)
        _save_segmentation(final_masks, prediction_paths[iteration])
| 509 | + |
453 | 510 |
|
454 | | - masks, logits_masks = [], [] |
455 | | - for m in batched_outputs: |
456 | | - mask, l_mask = [], [] |
457 | | - for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]): |
458 | | - best_iou_idx = torch.argmax(_iou) |
459 | | - mask.append(_sigmoid(_m[best_iou_idx][None])) |
460 | | - l_mask.append(_l[best_iou_idx][None]) |
461 | | - mask, l_mask = torch.stack(mask), torch.stack(l_mask) |
462 | | - masks.append(mask) |
463 | | - logits_masks.append(l_mask) |
464 | | - masks, logits_masks = torch.stack(masks), torch.stack(logits_masks) |
465 | | - masks = (masks > 0.5).to(torch.float32) |
466 | | - |
467 | | - for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks): |
468 | | - net_coords, net_labels = prompt_generator(_gt, _pred, _inp["point_coords"], _inp["point_labels"]) |
469 | | - _inp["point_coords"], _inp["point_labels"], _inp["mask_inputs"] = net_coords, net_labels, logits |
def run_inference_with_iterative_prompting(
    checkpoint_path: Union[str, os.PathLike],
    model_type: str,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    prediction_root: Union[str, os.PathLike],
    use_boxes: bool,
    n_iterations: int = 8,
    batch_size: int = 32,
) -> None:
    """@private

    Run iterative prompting for a list of images and write the predictions to disk.

    For every image one segmentation per iteration is written to
    `<prediction_root>/iteration<NN>/<image file name>`. Images for which all
    of these files already exist are skipped.

    Args:
        checkpoint_path: Checkpoint passed to `get_trainable_sam_model`.
        model_type: Model type passed to `get_trainable_sam_model`.
        image_paths: Paths of the input images.
        gt_paths: Paths of the ground-truth label images, paired with
            `image_paths` by position.
        prediction_root: Folder in which the per-iteration folders are created.
        use_boxes: Whether the initial prompts are boxes instead of a point.
        n_iterations: Number of prompting iterations per image.
        batch_size: Number of objects processed per model call.

    Raises:
        AssertionError: If an image or ground-truth path does not exist.
    """
    warnings.warn("The iterative prompting functionality is not working correctly yet.")

    # Fall back to the CPU when no GPU is available instead of hard-coding
    # CUDA, which fails on machines without one.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_trainable_sam_model(model_type, checkpoint_path)

    # Create all prediction folders up front; the same folder paths are reused
    # below to build the per-image output paths.
    iteration_folders = [os.path.join(prediction_root, f"iteration{i:02}") for i in range(n_iterations)]
    for folder in iteration_folders:
        os.makedirs(folder, exist_ok=True)

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)

        prediction_paths = [os.path.join(folder, image_name) for folder in iteration_folders]
        # Skip images whose predictions for every iteration are already on disk.
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        with torch.no_grad():
            _run_inference_with_iterative_prompting_for_image(
                model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
            )
0 commit comments