Description
Hi Anwai and Constantin! Hope you are both doing well this winter season. I have made some headway in finetuning a decent model, but am running into issues deploying it on other datasets. Below I describe my overall task and how I am trying to get there with microSAM.
Big Picture: from whole-embryo LSFM data, predict the region where the somite structures are likely to be (semantic segmentation). The somite regions are sparse, about 10% of the total embryo (but my training data is dense; only planes that have masks are passed). Then, from that foreground, separate the masks into instances. I tested the code in notebooks, but I do the full training in Python scripts submitted to our cluster. The version of microSAM I'm using is 1.6.2.
Current progress: I have successfully fine-tuned a Micro-SAM vit_b_lm model for somite segmentation on 2D slices from volumetric mouse embryo data with great IOU scores (around 0.7-0.8). The results look fantastic (below), but are purely semantic (0=background and 1=foreground). I should be able to separate these into instances (which should be achievable by microSAM prompting, right?). I'd like to use this checkpoint to run inference on all unseen data (10 embryos) to extract foreground/semantic masks as a .tiff. Then that can be used as input for the prompting to get instance masks (I think...).
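Just to make the semantic-to-instance step concrete, the naive fallback I have in mind (if prompting turns out not to be the right tool) is plain connected-components labeling with scikit-image. The function name below is just for illustration, and touching somites would get merged into one label, which is exactly why prompting seems more appropriate:
# Naive fallback sketch: connected-components labeling of the binary foreground mask.
import numpy as np
from skimage.measure import label
def naive_instances_from_foreground(semantic_mask: np.ndarray) -> np.ndarray:
    """Turn a 0/1 somite mask into integer instance labels (one label per connected blob)."""
    return label(semantic_mask > 0)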
The first step is to run inference. I used to have notebook code working with an earlier version of microSAM, but now I get unpickling errors (below). I've brainstormed with ChatGPT, but it keeps leading me outside the microSAM API. Its consensus is that the checkpoint file alone is not sufficient to run inference outside of the training environment. I feel like this can't be right, since my goal is straightforward, so I'm probably missing something critical. How do I use a checkpoint file from training to run inference outside of the training environment? ChatGPT claimed that the current functions in the API documentation are incompatible with my intended purpose, which makes me question whether I'm using microSAM "off label". I could use a bit of help deciding on a path forward with the checkpoint that I have, and whether you think my approach to getting instance masks is infeasible.
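For reference, this is the kind of workaround I've been considering (I'm not at all sure it is the intended micro_sam workflow, and the "model_state" key below is my assumption about how the trainer checkpoint is structured): inside the training environment, where SomiteSegmentationDataset is importable, re-save only the model weights and point the inference script at that stripped file.
# Hypothetical workaround, run once inside the training environment (paths are placeholders).
import torch
ckpt = torch.load("path/to/best.pt", map_location="cpu", weights_only=False)  # my own file, so I trust it
torch.save({"model_state": ckpt["model_state"]}, "path/to/somite_sam_weights_only.pt")  # assumes a "model_state" key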
Thanks for reading!
The code (run as a Python script; it previously worked in a notebook with an earlier version of microSAM)
# imports
import os
from typing import Optional, Tuple, Union

import zarr
import numpy as np
import torch
import matplotlib.pyplot as plt

from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
def run_automatic_instance_segmentation(
    image: np.ndarray,
    checkpoint_path: Union[os.PathLike, str],
    model_type: str = "vit_b_lm",
    device: Optional[Union[str, torch.device]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
):
    """Automatic Instance Segmentation (AIS) by training an additional instance decoder in SAM.

    NOTE: AIS is supported only for `µsam` models.

    Args:
        image: The input image.
        checkpoint_path: The path to stored checkpoints.
        model_type: The choice of the `µsam` model.
        device: The device to run the model inference.
        tile_shape: The tile shape for tiling-based segmentation.
        halo: The overlap shape on each side per tile for stitching the segmented tiles.

    Returns:
        The instance segmentation.
    """
    # Step 1: Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,  # choice of the Segment Anything model
        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model.
        device=device,  # the device to run the model inference.
        is_tiled=(tile_shape is not None),  # whether to run tiled segmentation.
    )

    # Step 2: Get the instance segmentation for the given image.
    prediction = automatic_instance_segmentation(
        predictor=predictor,  # the predictor for the Segment Anything model.
        segmenter=segmenter,  # the segmenter class responsible for generating predictions.
        input_path=image,  # the filepath to image or the input array for automatic segmentation.
        ndim=2,  # the number of input dimensions.
        tile_shape=tile_shape,  # the tile shape for tiling-based prediction.
        halo=halo,  # the overlap shape for tiling-based prediction.
    )

    return prediction
# Path definition
raw_path = "/net/beliveau/vol2/instrument/E9.5_300/Zoom_300/dataset_fused_rechunked_blocks.zarr"
raw_key = "ch2/s3"
checkpoint_path = ("/net/beliveau/vol1/project/EKN_whole_embryo_structural_analysis/Somites/uSAM_finetuning/checkpoints/somite_usam_finetune_old_code_test_5kiter_1e-5lr_vit_b_LM/best.pt")
raw = zarr.open(raw_path, mode="r")[raw_key]
print(f"Loaded raw volume: {raw.shape}, dtype={raw.dtype}")
# defining padding function
def center_pad(image, target_shape=(1024, 1024)):
    # assumes each slice is no larger than target_shape in either dimension
    x, y = image.shape
    tx, ty = target_shape
    pad_x = max((tx - x) // 2, 0)
    pad_y = max((ty - y) // 2, 0)
    pad_width = ((pad_x, tx - x - pad_x), (pad_y, ty - y - pad_y))
    return np.pad(image, pad_width, mode="constant", constant_values=0)
# running inference on all the slices in the z-stack
device = "cuda" if torch.cuda.is_available() else "cpu"
n_slices = raw.shape[0]
pred_stack = []
print(f"Running inference on {n_slices} slices...")
for z in range(n_slices):
    raw_slice = raw[z, :, :].astype(np.float32)
    padded = center_pad(raw_slice)

    # min-max normalize, guarding against empty slices
    if padded.max() > 0:
        normed = (padded - padded.min()) / (padded.max() - padded.min())
    else:
        normed = np.zeros_like(padded)

    # replicate to 3 channels, since SAM expects an RGB-like input
    img_3ch = np.repeat(normed[:, :, None], 3, axis=2).astype(np.float32)

    pred_mask_bin = run_automatic_instance_segmentation(
        image=img_3ch,
        checkpoint_path=checkpoint_path,
        model_type="vit_b_lm",  # must match the architecture of the finetuned checkpoint
        device=device,
    )
    # collapse the instance labels into a binary foreground mask
    pred_mask_bin = pred_mask_bin.astype(bool)
    pred_stack.append(pred_mask_bin)
pred_stack = np.stack(pred_stack)
print("Inference complete!")
print(f"Predicted stack shape: {pred_stack.shape}")
Error:
---------------------------------------------------------------------------
Traceback (most recent call last):
File "/net/beliveau/vol1/project/EKN_whole_embryo_structural_analysis/Somites/uSAM_finetuning/microSAM_inference_to_instance_mask_script.py", line 54, in <module>
sam, _ = get_sam_model(
~~~~~~~~~~~~~^
model_type="vit_b_lm",
^^^^^^^^^^^^^^^^^^^^^^
...<2 lines>...
return_sam=True
^^^^^^^^^^^^^^^
)
^
File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/micro_sam/util.py", line 425, in get_sam_model
sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/micro_sam/models/build_sam.py", line 68, in build_sam_vit_b
return _build_sam(
encoder_embed_dim=768,
...<5 lines>...
image_size=image_size,
)
File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/micro_sam/models/build_sam.py", line 139, in _build_sam
state_dict = torch.load(f)
File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/torch/serialization.py", line 1470, in load
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL __main__.SomiteSegmentationDataset was not an allowed global by default. Please use `torch.serialization.add_safe_globals([SomiteSegmentationDataset])` or the `torch.serialization.safe_globals([SomiteSegmentationDataset])` context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
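For completeness, this is how I read option (2) from that message (I have not confirmed it is the recommended way to handle a finetuned micro_sam checkpoint, and the import below is hypothetical, since the dataset class lives in my training script):
# Allowlist the pickled class before micro_sam calls torch.load internally.
# This presumably only helps in an environment where SomiteSegmentationDataset
# can actually be imported, which is exactly what I don't have at inference time.
import torch
from somite_training_script import SomiteSegmentationDataset  # hypothetical module name
torch.serialization.add_safe_globals([SomiteSegmentationDataset])
# ...then call get_predictor_and_segmenter / get_sam_model as before.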
