
Troubles using checkpoint file to headlessly run inference on unseen data (unpickling errors?) #1132

@evaknichols

Description

Hi Anwai and Constantin! Hope you are both doing well this winter season. I have made some headway in finetuning a decent model, but am running into issues deploying it on other datasets. Below I describe my overall task and how I am trying to get there with microSAM.

Big picture: from whole-embryo LSFM data, predict the region where the somite structures are likely to be (semantic segmentation). The somite regions are sparse, about 10% of the total embryo (but my training data is dense; only planes that have masks are passed). Then, from that foreground, separate the masks into instances. I tested the code in notebooks, but run the full training as Python scripts submitted to our cluster. The version of microSAM I'm using is 1.6.2.

Current progress: I have successfully fine-tuned a Micro-SAM vit_b_lm model for somite segmentation on 2D slices from volumetric mouse embryo data, with good IoU scores (around 0.7-0.8). The results look fantastic (below), but are purely semantic (0 = background, 1 = foreground). I should be able to separate these into instances (which should be achievable with microSAM prompting, right?). I'd like to use this checkpoint to run inference on all unseen data (10 embryos) and extract the foreground/semantic masks as a .tiff. That output would then be used as input for prompting to get instance masks (I think...).

The first step is to run inference. I used to have notebook code working with an earlier version of microSAM, but now I get unpickling errors (below). I've brainstormed with ChatGPT, but it keeps leading me out of the microSAM API. Its conclusion is that the checkpoint file alone is not sufficient to run inference outside of the training environment. I feel like that can't be right, since my goal is straightforward, so I'm probably missing something critical. How do I use a checkpoint file from training to run inference outside of the training environment? ChatGPT also claimed that the current functions in the API documentation are incompatible with my intended purpose, which makes me question whether I'm using microSAM in an "off-label" way. I could use a bit of help deciding on a path forward with the checkpoint that I have, and knowing whether you think my approach to getting instance masks is feasible.
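
For reference, the workaround I'm currently leaning towards is to open the checkpoint once in the training environment (where my SomiteSegmentationDataset class is defined) with weights_only=False, and re-save just the model weights so that the resulting file loads anywhere. This is only a sketch: I'm assuming the torch_em-style checkpoint keeps the weights under a "model_state" key, and that micro_sam will accept the stripped file through checkpoint_path the same way it accepts best.pt.

import torch

# Run this once in the training environment, where SomiteSegmentationDataset
# is defined, so that regular (non-weights-only) unpickling succeeds.
# weights_only=False is acceptable here because I created the checkpoint myself.
ckpt = torch.load(
    "checkpoints/somite_usam_finetune_old_code_test_5kiter_1e-5lr_vit_b_LM/best.pt",
    map_location="cpu",
    weights_only=False,
)

# Assumption: the model weights live under the "model_state" key.
state = ckpt["model_state"] if isinstance(ckpt, dict) and "model_state" in ckpt else ckpt
torch.save(state, "best_model_state_only.pt")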

Thanks for reading!

The code (run as a Python script; it worked in a notebook with an earlier version of microSAM):

# imports
import os
from typing import Optional, Tuple, Union

import zarr
import numpy as np
import torch
import matplotlib.pyplot as plt

from micro_sam.automatic_segmentation import (
    automatic_instance_segmentation,
    get_predictor_and_segmenter,
)

def run_automatic_instance_segmentation(
    image: np.ndarray,
    checkpoint_path: Union[os.PathLike ,str],
    model_type: str = "vit_b_lm",
    device: Optional[Union[str, torch.device]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
):
    """Automatic Instance Segmentation (AIS) by training an additional instance decoder in SAM.

    NOTE: AIS is supported only for `µsam` models.

    Args:
        image: The input image.
        checkpoint_path: The path to stored checkpoints.
        model_type: The choice of the `µsam` model.
        device: The device to run the model inference.
        tile_shape: The tile shape for tiling-based segmentation.
        halo: The overlap shape on each side per tile for stitching the segmented tiles.

    Returns:
        The instance segmentation.
    """
    # Step 1: Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,  # choice of the Segment Anything model
        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model.
        device=device,  # the device to run the model inference.
        is_tiled=(tile_shape is not None),  # whether to run tiling-based segmentation.
    )

    # Step 2: Get the instance segmentation for the given image.
    prediction = automatic_instance_segmentation(
        predictor=predictor,  # the predictor for the Segment Anything model.
        segmenter=segmenter,  # the segmenter class responsible for generating predictions.
        input_path=image,  # the filepath to image or the input array for automatic segmentation.
        ndim=2,  # the number of input dimensions.
        tile_shape=tile_shape,  # the tile shape for tiling-based prediction.
        halo=halo,  # the overlap shape for tiling-based prediction.
    )

    return prediction

# Path definition
raw_path = "/net/beliveau/vol2/instrument/E9.5_300/Zoom_300/dataset_fused_rechunked_blocks.zarr"
raw_key = "ch2/s3"

checkpoint_path = ("/net/beliveau/vol1/project/EKN_whole_embryo_structural_analysis/Somites/uSAM_finetuning/checkpoints/somite_usam_finetune_old_code_test_5kiter_1e-5lr_vit_b_LM/best.pt")

raw = zarr.open(raw_path, mode="r")[raw_key]
print(f"Loaded raw volume: {raw.shape}, dtype={raw.dtype}")

# defining padding function 
def center_pad(image, target_shape=(1024, 1024)):
    x, y = image.shape
    tx, ty = target_shape
    pad_x = max((tx - x) // 2, 0)
    pad_y = max((ty - y) // 2, 0)
    pad_width = ((pad_x, tx - x - pad_x), (pad_y, ty - y - pad_y))
    return np.pad(image, pad_width, mode="constant", constant_values=0)
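
For completeness, here is the inverse crop I plan to apply to each prediction so the mask stack lines up with the raw volume again; this helper is my own, not part of the microSAM API:

def center_crop(mask, original_shape):
    # Undo center_pad: recover the region corresponding to the original slice.
    x, y = original_shape
    tx, ty = mask.shape
    start_x = max((tx - x) // 2, 0)
    start_y = max((ty - y) // 2, 0)
    return mask[start_x:start_x + x, start_y:start_y + y]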

# running inference on all the slices in the z-stack
device = "cuda" if torch.cuda.is_available() else "cpu"
n_slices = raw.shape[0]
pred_stack = []

print(f"Running inference on {n_slices} slices...")

for z in range(n_slices):
    raw_slice = raw[z, :, :].astype(np.float32)
    padded = center_pad(raw_slice)
    if padded.max() > 0:
        normed = (padded - padded.min()) / (padded.max() - padded.min())
    else:
        normed = np.zeros_like(padded)

    img_3ch = np.repeat(normed[:, :, None], 3, axis=2).astype(np.float32)

    pred_mask_bin = run_automatic_instance_segmentation(
        image=img_3ch,
        checkpoint_path=checkpoint_path,
        model_type="vit_b",
        device=device
    )
    pred_mask_bin = pred_mask_bin.astype(bool)
    pred_stack.append(pred_mask_bin)

pred_stack = np.stack(pred_stack)
print("Inference complete!")
print(f"Predicted stack shape: {pred_stack.shape}")

Error:

---------------------------------------------------------------------------
Traceback (most recent call last):
  File "/net/beliveau/vol1/project/EKN_whole_embryo_structural_analysis/Somites/uSAM_finetuning/microSAM_inference_to_instance_mask_script.py", line 54, in <module>
    sam, _ = get_sam_model(
             ~~~~~~~~~~~~~^
        model_type="vit_b_lm",
        ^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        return_sam=True
        ^^^^^^^^^^^^^^^
    )
    ^
  File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/micro_sam/util.py", line 425, in get_sam_model
    sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
  File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/micro_sam/models/build_sam.py", line 68, in build_sam_vit_b
    return _build_sam(
        encoder_embed_dim=768,
    ...<5 lines>...
        image_size=image_size,
    )
  File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/micro_sam/models/build_sam.py", line 139, in _build_sam
    state_dict = torch.load(f)
  File "/net/beliveau/vol1/home/eknich/.conda/envs/sam/lib/python3.13/site-packages/torch/serialization.py", line 1470, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL __main__.SomiteSegmentationDataset was not an allowed global by default. Please use `torch.serialization.add_safe_globals([SomiteSegmentationDataset])` or the `torch.serialization.safe_globals([SomiteSegmentationDataset])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
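
The only alternative I've found so far is the allow-listing route the error message itself suggests, which only works if the dataset class is importable in the inference script. A sketch, where somite_training is a hypothetical module standing in for my actual training script:

import torch.serialization
# Hypothetical import: wherever SomiteSegmentationDataset is actually defined.
from somite_training import SomiteSegmentationDataset

# Allow-list the custom class so the weights-only unpickler used inside
# micro_sam's torch.load call can handle the checkpoint.
torch.serialization.add_safe_globals([SomiteSegmentationDataset])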

[Image attachment]
