
Commit b5face5

Merge pull request #150 from computational-cell-analytics/finetune-example
Finetune example
2 parents 4c48e68 + c34d219 commit b5face5

8 files changed: +230 additions, -17 deletions

README.md

Lines changed: 1 addition & 7 deletions

@@ -16,13 +16,7 @@ We implement napari applications for:
 <img src="https://github.com/computational-cell-analytics/micro-sam/assets/4263537/dfca3d9b-dba5-440b-b0f9-72a0683ac410" width="256">
 <img src="https://github.com/computational-cell-analytics/micro-sam/assets/4263537/aefbf99f-e73a-4125-bb49-2e6592367a64" width="256">
 
-**Beta version**
-
-This is an advanced beta version. While many features are still under development, we aim to keep the user interface and python library stable.
-Any feedback is welcome, but please be aware that the functionality is under active development and that some features may not be thoroughly tested yet.
-We will soon provide a stand-alone application for running the `micro_sam` annotation tools, and plan to also release it as [napari plugin](https://napari.org/stable/plugins/index.html) in the future.
-
-If you run into any problems or have questions please open an issue on Github or reach out via [image.sc](https://forum.image.sc/) using the tag `micro-sam` and tagging @constantinpape.
+If you run into any problems or have questions regarding our tool please open an issue on Github or reach out via [image.sc](https://forum.image.sc/) using the tag `micro-sam` and tagging @constantinpape.
 
 ## Installation and Usage

examples/README.md

Lines changed: 3 additions & 0 deletions

@@ -6,5 +6,8 @@ Examples for using the micro_sam annotation tools:
 - `annotator_tracking.py`: run the interactive tracking annotation tool
 - `image_series_annotator.py`: run the annotation tool for a series of images
 
+The folder `finetuning` contains example scripts that show how a Segment Anything model can be fine-tuned
+on custom data with the `micro_sam.training` library, and how the finetuned models can then be used within the annotation tools.
+
 The folder `use_as_library` contains example scripts that show how `micro_sam` can be used as a python
 library to apply Segment Anything to multi-dimensional data.

examples/finetuning/.gitignore

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+checkpoints/
+logs/
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
3+
import numpy as np
4+
import torch
5+
import torch_em
6+
7+
import micro_sam.training as sam_training
8+
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
9+
from micro_sam.util import export_custom_sam_model
10+
11+
DATA_FOLDER = "data"
12+
13+
14+
def get_dataloader(split, patch_shape, batch_size):
15+
"""Return train or val data loader for finetuning SAM.
16+
17+
The data loader must be a torch data loader that retuns `x, y` tensors,
18+
where `x` is the image data and `y` are the labels.
19+
The labels have to be in a label mask instance segmentation format.
20+
I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
21+
Important: the ID 0 is reseved for background, and the IDs must be consecutive
22+
23+
Here, we use `torch_em.default_segmentation_loader` for creating a suitable data loader from
24+
the example hela data. You can either adapt this for your own data (see comments below)
25+
or write a suitable torch dataloader yourself.
26+
"""
27+
assert split in ("train", "val")
28+
os.makedirs(DATA_FOLDER, exist_ok=True)
29+
30+
# This will download the image and segmentation data for training.
31+
image_dir = fetch_tracking_example_data(DATA_FOLDER)
32+
segmentation_dir = fetch_tracking_segmentation_data(DATA_FOLDER)
33+
34+
# torch_em.default_segmentation_loader is a convenience function to build a torch dataloader
35+
# from image data and labels for training segmentation models.
36+
# It supports image data in various formats. Here, we load image data and labels from the two
37+
# folders with tif images that were downloaded by the example data functionality, by specifying
38+
# `raw_key` and `label_key` as `*.tif`. This means all images in the respective folders that end with
39+
# .tif will be loadded.
40+
# The function supports many other file formats. For example, if you have tif stacks with multiple slices
41+
# instead of multiple tif images in a foldder, then you can pass raw_key=label_key=None.
42+
43+
# Load images from multiple files in folder via pattern (here: all tif files)
44+
raw_key, label_key = "*.tif", "*.tif"
45+
# Alternative: if you have tif stacks you can just set raw_key and label_key to None
46+
# raw_key, label_key= None, None
47+
48+
# The 'roi' argument can be used to subselect parts of the data.
49+
# Here, we use it to select the first 70 frames fro the test split and the other frames for the val split.
50+
if split == "train":
51+
roi = np.s_[:70, :, :]
52+
else:
53+
roi = np.s_[70:, :, :]
54+
55+
loader = torch_em.default_segmentation_loader(
56+
raw_paths=image_dir, raw_key=raw_key,
57+
label_paths=segmentation_dir, label_key=label_key,
58+
patch_shape=patch_shape, batch_size=batch_size,
59+
ndim=2, is_seg_dataset=True, rois=roi,
60+
label_transform=torch_em.transform.label.connected_components,
61+
)
62+
return loader
63+
64+
65+
def run_training(checkpoint_name, model_type):
66+
"""Run the actual model training."""
67+
68+
# All hyperparameters for training.
69+
batch_size = 1 # the training batch size
70+
patch_shape = (1, 512, 512) # the size of patches for training
71+
n_objects_per_batch = 25 # the number of objects per batch that will be sampled
72+
device = torch.device("cuda") # the device/GPU used for training
73+
n_iterations = 10000 # how long we train (in iterations)
74+
75+
# Get the dataloaders.
76+
train_loader = get_dataloader("train", patch_shape, batch_size)
77+
val_loader = get_dataloader("val", patch_shape, batch_size)
78+
79+
# Get the segment anything model, the optimizer and the LR scheduler
80+
model = sam_training.get_trainable_sam_model(model_type=model_type, device=device)
81+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
82+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
83+
84+
# This class creates all the training data for a batch (inputs, prompts and labels).
85+
convert_inputs = sam_training.ConvertToSamInputs()
86+
87+
# the trainer which performs training and validation (implemented using "torch_em")
88+
trainer = sam_training.SamTrainer(
89+
name=checkpoint_name,
90+
train_loader=train_loader,
91+
val_loader=val_loader,
92+
model=model,
93+
optimizer=optimizer,
94+
# currently we compute loss batch-wise, else we pass channelwise True
95+
loss=torch_em.loss.DiceLoss(channelwise=False),
96+
metric=torch_em.loss.DiceLoss(),
97+
device=device,
98+
lr_scheduler=scheduler,
99+
logger=sam_training.SamLogger,
100+
log_image_interval=10,
101+
mixed_precision=True,
102+
convert_inputs=convert_inputs,
103+
n_objects_per_batch=n_objects_per_batch,
104+
n_sub_iteration=8,
105+
compile_model=False
106+
)
107+
trainer.fit(n_iterations)
108+
109+
110+
def export_model(checkpoint_name, model_type):
111+
"""Export the trained model."""
112+
# export the model after training so that it can be used by the rest of the micro_sam library
113+
export_path = "./finetuned_hela_model.pth"
114+
checkpoint_path = os.path.join("checkpoints", checkpoint_name, "best.pt")
115+
export_custom_sam_model(
116+
checkpoint_path=checkpoint_path,
117+
model_type=model_type,
118+
save_path=export_path,
119+
)
120+
121+
122+
def main():
123+
"""Finetune a Segment Anything model.
124+
125+
This example uses image data and segmentations from the cell tracking challenge,
126+
but can easily be adapted for other data (including data you have annoated with micro_sam beforehand).
127+
"""
128+
# The model_type determines which base model is used to initialize the weights that are finetuned.
129+
# We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
130+
model_type = "vit_b"
131+
132+
# The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'
133+
checkpoint_name = "sam_hela"
134+
135+
run_training(checkpoint_name, model_type)
136+
export_model(checkpoint_name, model_type)
137+
138+
139+
if __name__ == "__main__":
140+
main()
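
Note (not part of the committed file): the dataloader docstring above requires instance labels with 0 reserved for background and consecutive object IDs. If your own annotations are stored differently, for example as binary masks or with arbitrary IDs, a small label transform along the lines of the `connected_components` transform used above brings them into that format. The following is a minimal sketch assuming `numpy` and `scikit-image` are installed; `to_instance_labels` is an illustrative helper, not part of the micro_sam API.

# Minimal sketch (not part of this commit): convert a label image into the format the
# dataloader expects, i.e. 0 = background and consecutive object IDs 1, 2, ..., N.
import numpy as np
from skimage.measure import label as connected_components
from skimage.segmentation import relabel_sequential


def to_instance_labels(mask: np.ndarray) -> np.ndarray:
    if set(np.unique(mask)) <= {0, 1}:
        # Binary mask: split the foreground into objects via connected components.
        return connected_components(mask).astype("int64")
    # Instance mask with arbitrary IDs: keep the objects, but make the IDs consecutive.
    relabeled, _, _ = relabel_sequential(mask.astype("int64"))
    return relabeled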
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import imageio.v3 as imageio
2+
3+
import micro_sam.util as util
4+
from micro_sam.sam_annotator import annotator_2d
5+
6+
7+
def run_annotator_with_custom_model():
8+
"""Run the 2d anntator with a custom (finetuned) model.
9+
10+
Here, we use the model that is produced by `finetuned_hela.py` and apply it
11+
for an image from the validation set.
12+
"""
13+
# take the last frame, which is part of the val set, so the model was not directly trained on it
14+
im = imageio.imread("./data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01/t083.tif")
15+
16+
# set the checkpoint and the path for caching the embeddings
17+
checkpoint = "./finetuned_hela_model.pth"
18+
embedding_path = "./embeddings/embeddings-finetuned.zarr"
19+
20+
model_type = "vit_b" # We finetune a vit_b in the example script.
21+
# Adapt this if you finetune a different model type, e.g. vit_h.
22+
23+
# Load the custom model.
24+
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
25+
26+
# Run the 2d annotator with the custom model.
27+
annotator_2d(
28+
im, embedding_path=embedding_path, predictor=predictor, precompute_amg_state=True,
29+
)
30+
31+
32+
if __name__ == "__main__":
33+
run_annotator_with_custom_model()

finetuning/README.md

Lines changed: 11 additions & 2 deletions

@@ -1,8 +1,8 @@
 # Segment Anything Finetuning
 
-Preliminary examples for fine-tuning segment anything on custom datasets.
+Code for finetuning Segment Anything on microscopy data and evaluating the finetuned models.
 
-## LiveCELL
+## Example: LiveCELL
 
 **Finetuning**
 
@@ -47,3 +47,12 @@ E.g. run the script like below to evaluate the previous predictions.
 python livecell_evaluation.py -i /scratch/projects/nim00007/data/LiveCELL -e experiment
 ```
 This will create a folder `experiment/results` with csv tables with the results per cell type and averaged over all images.
+
+
+## Finetuning and evaluation code
+
+The subfolders contain the code for different finetuning and evaluation experiments for microscopy data:
+- `livecell`: TODO
+- `generalist`: TODO
+
+Note: we still need to clean up most of this code and will add it later.

micro_sam/sample_data.py

Lines changed: 35 additions & 6 deletions

@@ -9,7 +9,7 @@
 import pooch
 
 
-def fetch_image_series_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+def fetch_image_series_example_data(save_directory: Union[str, os.PathLike]) -> str:
     """Download the sample images for the image series annotator.
 
     Args:
@@ -36,7 +36,7 @@ def fetch_image_series_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
     return data_folder
 
 
-def fetch_wholeslide_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+def fetch_wholeslide_example_data(save_directory: Union[str, os.PathLike]) -> str:
     """Download the sample data for the 2d annotator.
 
     This downloads part of a whole-slide image from the NeurIPS Cell Segmentation Challenge.
@@ -61,7 +61,7 @@ def fetch_wholeslide_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
     return os.path.join(save_directory, fname)
 
 
-def fetch_livecell_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+def fetch_livecell_example_data(save_directory: Union[str, os.PathLike]) -> str:
     """Download the sample data for the 2d annotator.
 
     This downloads a single image from the LiveCELL dataset.
@@ -86,7 +86,7 @@ def fetch_livecell_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
     return os.path.join(save_directory, fname)
 
 
-def fetch_hela_2d_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+def fetch_hela_2d_example_data(save_directory: Union[str, os.PathLike]) -> str:
     """Download the sample data for the 2d annotator.
 
     This downloads a single image from the HeLa CTC dataset.
@@ -110,7 +110,7 @@ def fetch_hela_2d_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
     return os.path.join(save_directory, fname)
 
 
-def fetch_3d_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+def fetch_3d_example_data(save_directory: Union[str, os.PathLike]) -> str:
     """Download the sample data for the 3d annotator.
 
     This downloads the Lucchi++ datasets from https://casser.io/connectomics/.
@@ -139,7 +139,7 @@ def fetch_3d_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
     return str(lucchi_dir)
 
 
-def fetch_tracking_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+def fetch_tracking_example_data(save_directory: Union[str, os.PathLike]) -> str:
     """Download the sample data for the tracking annotator.
 
     This data is the cell tracking challenge dataset DIC-C2DH-HeLa.
@@ -171,3 +171,32 @@ def fetch_tracking_example_data(save_directory: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
     cell_tracking_dir = save_directory.joinpath(f"{fname}.unzip", "DIC-C2DH-HeLa", "01")
     assert os.path.exists(cell_tracking_dir)
     return str(cell_tracking_dir)
+
+
+def fetch_tracking_segmentation_data(save_directory: Union[str, os.PathLike]) -> str:
+    """Download groundtruth segmentation for the tracking example data.
+
+    This downloads the groundtruth segmentation for the image data from `fetch_tracking_example_data`.
+
+    Args:
+        save_directory: Root folder to save the downloaded data.
+    Returns:
+        The folder that contains the downloaded data.
+    """
+    save_directory = Path(save_directory)
+    os.makedirs(save_directory, exist_ok=True)
+    print("Example data directory is:", save_directory.resolve())
+    unpack_filenames = [os.path.join("masks", f"mask_{str(i).zfill(4)}.tif") for i in range(84)]
+    unpack = pooch.Unzip(members=unpack_filenames)
+    fname = "hela-ctc-01-gt.zip"
+    pooch.retrieve(
+        url="https://owncloud.gwdg.de/index.php/s/AWxQMblxwR99OjC/download",
+        known_hash="c0644d8ebe1390fb60125560ba15aa2342caf44f50ff0667a0318ea0ac6c958b",
+        fname=fname,
+        path=save_directory,
+        progressbar=True,
+        processor=unpack,
+    )
+    cell_tracking_dir = save_directory.joinpath(f"{fname}.unzip", "masks")
+    assert os.path.exists(cell_tracking_dir)
+    return str(cell_tracking_dir)
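
Note (not part of the commit): a short usage sketch of the new download helper together with its image counterpart, mirroring the calls in the finetuning example; "data" is just an example target folder.

from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data

# Both functions download into the given folder (if needed) and return the folder with the tif files.
image_dir = fetch_tracking_example_data("data")              # DIC-C2DH-HeLa image frames
segmentation_dir = fetch_tracking_segmentation_data("data")  # matching masks: mask_0000.tif ... mask_0083.tif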

micro_sam/util.py

Lines changed: 5 additions & 2 deletions

@@ -191,7 +191,8 @@ def get_custom_sam_model(
     custom_pickle = pickle
     custom_pickle.Unpickler = _CustomUnpickler
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
     sam = sam_model_registry[model_type]()
 
     # load the model state, ignoring any attributes that can't be found by pickle
@@ -230,7 +231,9 @@ def export_custom_sam_model(
         model_type: The SegmentAnything model type to use (vit_h, vit_b or vit_l).
         save_path: Where to save the exported model.
     """
-    _, state = get_custom_sam_model(checkpoint_path, model_type=model_type, return_state=True)
+    _, state = get_custom_sam_model(
+        checkpoint_path, model_type=model_type, return_state=True, device=torch.device("cpu"),
+    )
     model_state = state["model_state"]
     prefix = "sam."
     model_state = OrderedDict(
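
Note (not part of the commit): with this change the "cuda"/"cpu" fallback in `get_custom_sam_model` only applies when no device is passed, and `export_custom_sam_model` loads the state on the CPU, so the export also works on machines without a GPU. A small sketch of how the two functions can be called after this change; the checkpoint path is a placeholder, and the assumption that the first return value is the loaded predictor comes from the rest of micro_sam.util, not from this diff.

import torch
from micro_sam.util import get_custom_sam_model, export_custom_sam_model

# Explicit device: the new `if device is None` check leaves this choice untouched.
predictor, state = get_custom_sam_model(
    "checkpoints/sam_hela/best.pt", model_type="vit_b", return_state=True, device=torch.device("cpu"),
)

# Export the finetuned checkpoint so it can later be loaded via micro_sam.util.get_sam_model.
export_custom_sam_model(
    checkpoint_path="checkpoints/sam_hela/best.pt",
    model_type="vit_b",
    save_path="./finetuned_hela_model.pth",
)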
