
Commit e623e58

Merge of 2 parents: 47b73f4 + 69f3c01

8 files changed: +1278 −22 lines changed

micro_sam/automatic_segmentation.py

Lines changed: 21 additions & 15 deletions
```diff
@@ -87,7 +87,8 @@ def automatic_instance_segmentation(
         embedding_path: The path where the embeddings are cached already / will be saved.
         key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
             or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
-        ndim: The dimensionality of the data.
+        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
+            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
         tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
         halo: Overlap of the tiles for tiled prediction.
         verbose: Verbosity flag.
@@ -102,21 +103,12 @@ def automatic_instance_segmentation(
     else:
         image_data = util.load_image_data(input_path, key)
 
-    if ndim == 3 or image_data.ndim == 3:
-        if image_data.ndim != 3:
-            raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")
+    ndim = image_data.ndim if ndim is None else ndim
+
+    if ndim == 2:
+        if image_data.ndim != 2 and image_data.shape[-1] != 3:
+            raise ValueError(f"The input does not match the shape expectation of 2d inputs: {image_data.shape}")
 
-        instances = automatic_3d_segmentation(
-            volume=image_data,
-            predictor=predictor,
-            segmentor=segmenter,
-            embedding_path=embedding_path,
-            tile_shape=tile_shape,
-            halo=halo,
-            verbose=verbose,
-            **generate_kwargs
-        )
-    else:
         # Precompute the image embeddings.
         image_embeddings = util.precompute_image_embeddings(
             predictor=predictor,
@@ -142,6 +134,20 @@ def automatic_instance_segmentation(
             instances = np.zeros(this_shape, dtype="uint32")
         else:
             instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
+    else:
+        if image_data.ndim != 3 and image_data.shape[-1] != 3:
+            raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")
+
+        instances = automatic_3d_segmentation(
+            volume=image_data,
+            predictor=predictor,
+            segmentor=segmenter,
+            embedding_path=embedding_path,
+            tile_shape=tile_shape,
+            halo=halo,
+            verbose=verbose,
+            **generate_kwargs
+        )
 
     if output_path is not None:
         # Save the instance segmentation
```
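This refactoring routes 2d (including RGB) and 3d inputs through the same function, with `ndim` inferred only when it is not passed. A minimal usage sketch for the 2d RGB case; the input file is a placeholder, and `get_predictor_and_segmenter` is assumed to be the companion helper in this module:

```python
from micro_sam.automatic_segmentation import (
    automatic_instance_segmentation,
    get_predictor_and_segmenter,  # assumed companion helper in this module
)

# Load a SAM predictor plus segmenter ('vit_b_lm' assumed to be an available model).
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

# A 2d RGB image has image_data.ndim == 3, so the dimensionality cannot be
# inferred reliably; pass ndim=2 explicitly to select the 2d code path.
instances = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path="example_rgb_image.tif",  # hypothetical input file
    ndim=2,
)
```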

micro_sam/training/util.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -246,6 +246,11 @@ def __call__(self, x, y):
 #
 
 
+def normalize_to_8bit(raw):
+    raw = normalize(raw) * 255
+    return raw
+
+
 class ResizeRawTrafo:
     def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
         self.desired_shape = desired_shape
```
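The new `normalize_to_8bit` helper rescales data to the 8-bit value range; `normalize` is presumably torch_em's min-max normalization to [0, 1], so the result is float-valued in [0, 255] rather than a true uint8 cast. A self-contained sketch under that assumption:

```python
import numpy as np
from torch_em.transform.raw import normalize  # assumed source of `normalize`


def normalize_to_8bit(raw):
    raw = normalize(raw) * 255  # min-max normalize to [0, 1], then rescale to [0, 255]
    return raw


# E.g. bring 16-bit microscopy data into the 8-bit value range SAM expects.
raw = np.random.randint(0, 2 ** 16, size=(256, 256), dtype="uint16")
print(normalize_to_8bit(raw).min(), normalize_to_8bit(raw).max())  # ~0.0 255.0
```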

notebooks/sam_finetuning.ipynb

Lines changed: 3 additions & 7 deletions
```diff
@@ -232,7 +232,6 @@
     "import torch\n",
     "\n",
     "import torch_em\n",
-    "from torch_em.model import UNETR\n",
     "from torch_em.util.debug import check_loader\n",
     "from torch_em.data import MinInstanceSampler\n",
     "from torch_em.loss import DiceBasedDistanceLoss\n",
@@ -575,7 +574,7 @@
     "# Here, we load image data and labels from the two folders with tif images that were downloaded by the example data functionality,\n",
     "# by specifying `raw_key` and `label_key` as `*.tif`.\n",
     "# This means all images in the respective folders that end with .tif will be loaded.\n",
-    "# The function supports many other file formats. For example, if you have tif stacks with multiple slices instead of multiple tif images in a foldder,\n",
+    "# The function supports many other file formats. For example, if you have tif stacks with multiple slices instead of multiple tif images in a folder,\n",
     "# then you can pass raw_key=label_key=None.\n",
     "# For more information, here is the documentation: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/README.md\n",
     "# And here is a tutorial on creating dataloaders using 'torch-em': https://github.com/constantinpape/torch-em/blob/main/notebooks/tutorial_create_dataloaders.ipynb\n",
@@ -1096,18 +1095,15 @@
     "assert os.path.exists(best_checkpoint), \"Please train the model first to run inference on the finetuned model.\"\n",
     "assert train_instance_segmentation is True, \"Oops. You didn't opt for finetuning using the decoder-based automatic instance segmentation.\"\n",
     "\n",
-    "# # Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n",
+    "# Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n",
     "image_paths = image_paths[:5]\n",
     "\n",
     "for image_path in image_paths:\n",
     "    image = imageio.imread(image_path)\n",
     "    \n",
     "    # Predicted instances\n",
     "    prediction = run_automatic_instance_segmentation(\n",
-    "        image=image,\n",
-    "        checkpoint_path=best_checkpoint,\n",
-    "        model_type=model_type,\n",
-    "        device=device\n",
+    "        image=image, checkpoint_path=best_checkpoint, model_type=model_type, device=device\n",
     "    )\n",
     "\n",
     "    # Visualize the predictions\n",
```

workshops/README.md

Lines changed: 102 additions & 0 deletions
# Hands-On Analysis using `micro-sam`

## Upcoming Workshops:

1. I2K 2024 (Milan, Italy)
2. Virtual I2K 2024 (Online)

## Introduction

In this document, we walk you through the different steps involved in participating in hands-on image annotation experiments with our tool.

- Here is our [official documentation](https://computational-cell-analytics.github.io/micro-sam/) for a detailed explanation of our tools, library and the finetuned models.
- Here is the playlist for our [tutorial videos](https://youtube.com/playlist?list=PLwYZXQJ3f36GQPpKCrSbHjGiH39X4XjSO&si=3q-cIRD6KuoZFmAM) hosted on YouTube, elaborating in detail on the features of our tools.

## Steps:

### Step 1: Download the Datasets

- We provide the script `download_datasets.py` for automatically downloading the datasets to be used for interactive annotation with `micro-sam`.
- You can run the script as follows:

```bash
$ python download_datasets.py -i <DATA_DIRECTORY> -d <DATASET_NAME>
```

where `DATA_DIRECTORY` is the filepath to the directory where the datasets will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_datasets.py -h` in the terminal for more details).

> NOTE: We have chosen a) a subset of the CellPose `cyto` dataset, b) one volume from the EmbedSeg `Mouse-Skull-Nuclei-CBG` dataset's train split (namely, `X1.tif`), c) one volume from the Platynereis `Membrane` dataset's train split (namely, `train_data_membrane_02.n5`) and d) the entire `HPA` dataset for the following tasks in `micro-sam`.

### Step 2: Download the Precomputed Embeddings

- We provide the script `download_embeddings.py` for automatically downloading precomputed image embeddings for volumetric data to be used for interactive annotation with `micro-sam`.
- You can run the script as follows:

```bash
$ python download_embeddings.py -e <EMBEDDING_DIRECTORY> -d <DATASET_NAME>
```

where `EMBEDDING_DIRECTORY` is the filepath to the directory where the precomputed image embeddings will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_embeddings.py -h` in the terminal for more details).

### Additional Section: Precompute the Embeddings Yourself!

Here is an example guide for precomputing the image embeddings yourself (e.g. for volumetric data).

#### EmbedSeg

```bash
$ micro_sam.precompute_embeddings -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif  # Filepath where inputs are stored.
                                  -m vit_b  # You can provide the name of any model of your choice (supported by 'micro-sam') (e.g. 'vit_b_lm').
                                  -e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1  # Filepath where the computed embeddings will be cached.
```

#### Platynereis

```bash
$ micro_sam.precompute_embeddings -i data/platynereis/membrane/train_data_membrane_02.n5  # Filepath where inputs are stored.
                                  -k volumes/raw/s1  # Key to access the data group in container-style data structures.
                                  -m vit_b  # You can provide the name of any model of your choice (supported by 'micro-sam') (e.g. 'vit_b_em_organelles').
                                  -e embeddings/platynereis/vit_b/platynereis_train_data_membrane_02  # Filepath where the computed embeddings will be cached.
```

### Step 3: Run the `micro-sam` Annotators (WIP)

We recommend using the napari GUI for the interactive annotation. You can use the widget to specify all the essential parameters (e.g. the choice of model, the filepath to the precomputed embeddings, etc.).

TODO: add more details here.

Alternatively, you can use `micro-sam`'s CLI to start our annotator tools:

#### 2D Annotator (Cell Segmentation in Light Microscopy):

```bash
$ micro_sam.annotator_2d -i data/cellpose/cyto/test/...  # Filepath where the 2d image is stored.
                         -m vit_b  # You can provide the name of any model of your choice (supported by 'micro-sam') (e.g. 'vit_b_lm').
                         [OPTIONAL] -e embeddings/cellpose/vit_b/...  # Filepath where the computed embeddings will be cached (you can omit it to compute the embeddings on-the-fly).
```

#### 3D Annotator (EmbedSeg - Nuclei Segmentation in Light Microscopy):

```bash
$ micro_sam.annotator_3d -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif  # Filepath where the 3d volume is stored.
                         -m vit_b  # You can provide the name of any model of your choice (supported by 'micro-sam') (e.g. 'vit_b_lm').
                         -e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1.zarr  # Filepath where the computed embeddings will be cached (we RECOMMEND providing the path to the downloaded embeddings, or you can omit it to compute the embeddings on-the-fly).
```

#### 3D Annotator (Platynereis - Membrane Segmentation in Electron Microscopy):

```bash
$ micro_sam.annotator_3d -i data/platynereis/membrane/train_data_membrane_02.n5  # Filepath where the 3d volume is stored.
                         -k volumes/raw/s1  # Key to access the data group in container-style data structures.
                         -m vit_b  # You can provide the name of any model of your choice (supported by 'micro-sam') (e.g. 'vit_b_em_organelles').
                         -e embeddings/platynereis/vit_b/...  # Filepath where the computed embeddings will be cached (we RECOMMEND providing the path to the downloaded embeddings, or you can omit it to compute the embeddings on-the-fly).
```

#### Image Series Annotator (Multiple Light Microscopy 2D Images for Cell Segmentation):

```bash
$ micro_sam.image_series_annotator -i ...  # Filepath where the 2d images are stored.
                                   -m vit_b  # You can provide the name of any model of your choice (supported by 'micro-sam') (e.g. 'vit_b_lm').
```

### Step 4: Finetune Segment Anything on Microscopy Images (WIP)

- We provide a notebook `finetune_sam.ipynb` / a script `finetune_sam.py` for finetuning the Segment Anything Model for cell segmentation in confocal microscopy images; see the launch sketch below.
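Both entry points cover the same workflow; a sketch of the assumed launch commands (the training configuration itself lives inside the notebook / script):

```bash
# Interactive version:
$ jupyter lab finetune_sam.ipynb

# Script version (adjust the configuration variables in the script first):
$ python finetune_sam.py
```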

workshops/download_datasets.py

Lines changed: 140 additions & 0 deletions
```python
import os
from glob import glob
from natsort import natsorted

from torch_em.data import datasets
from torch_em.util.image import load_data


def _download_sample_data(path, data_dir, url, checksum, download):
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "data.zip")
    datasets.util.download_source(path=zip_path, url=url, download=download, checksum=checksum)
    datasets.util.unzip(zip_path=zip_path, dst=path)


def _get_cellpose_sample_data_paths(path, download):
    data_dir = os.path.join(path, "cellpose", "cyto", "test")

    url = "https://owncloud.gwdg.de/index.php/s/slIxlmsglaz0HBE/download"
    checksum = "4d1ce7afa6417d051b93d6db37675abc60afe68daf2a4a5db0c787d04583ce8a"

    _download_sample_data(path, data_dir, url, checksum, download)

    raw_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
    label_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))

    return raw_paths, label_paths


def _get_hpa_data_paths(path, split, download):
    urls = [
        "https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download",  # train
        "https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download",  # val
        "https://owncloud.gwdg.de/index.php/s/8tLY5jPmpw37beM/download",  # test
    ]
    checksums = [
        "6e5f3ec6b0d505511bea752adaf35529f6b9bb9e7729ad3bdd90ffe5b2d302ab",  # train
        "4d7a4188cc3d3877b3cf1fbad5f714ced9af4e389801e2136623eac2fde78e9c",  # val
        "8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe",  # test
    ]
    splits = ["train", "val", "test"]
    assert split in splits, f"'{split}' is not a valid split."

    for url, checksum, _split in zip(urls, checksums, splits):
        data_dir = os.path.join(path, _split)
        _download_sample_data(path, data_dir, url, checksum, download)

    raw_paths = natsorted(glob(os.path.join(path, split, "images", "*.tif")))

    if split == "test":  # The 'test' split for HPA does not have labels.
        return raw_paths, None
    else:
        label_paths = natsorted(glob(os.path.join(path, split, "labels", "*.tif")))
        return raw_paths, label_paths


def _get_dataset_paths(path, dataset_name, view=False):
    dataset_paths = {
        # 2d LM datasets for cell segmentation.
        "cellpose": lambda: _get_cellpose_sample_data_paths(path=os.path.join(path, "cellpose"), download=True),
        "hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True, split="train"),
        # 3d LM dataset for nuclei segmentation.
        "embedseg": lambda: datasets.embedseg_data.get_embedseg_paths(
            path=os.path.join(path, "embedseg"), name="Mouse-Skull-Nuclei-CBG", split="train", download=True,
        ),
        # 3d EM dataset for membrane segmentation.
        "platynereis": lambda: datasets.platynereis.get_platynereis_paths(
            path=os.path.join(path, "platynereis"), sample_ids=None, name="cells", download=True,
        ),
    }

    dataset_keys = {
        "cellpose": [None, None],
        "hpa": [None, None],
        "embedseg": [None, None],
        "platynereis": ["volumes/raw/s1", "volumes/labels/segmentation/s1"],
    }

    if dataset_name is None:  # Download all datasets.
        dataset_names = list(dataset_paths.keys())
    else:  # Download a specific dataset.
        dataset_names = [dataset_name]

    for dname in dataset_names:
        if dname not in dataset_paths:
            raise ValueError(
                f"'{dname}' is not a supported dataset enabled for download. "
                f"Please choose from {list(dataset_paths.keys())}."
            )

        paths = dataset_paths[dname]()
        print(f"'{dname}' has been downloaded to {path}.")

        if view:
            import napari

            if isinstance(paths, tuple):  # Datasets with explicit raw and label paths.
                raw_paths, label_paths = paths
            else:
                raw_paths = label_paths = paths

            raw_key, label_key = dataset_keys[dname]
            for raw_path, label_path in zip(raw_paths, label_paths):
                raw = load_data(raw_path, raw_key)
                labels = load_data(label_path, label_key)

                v = napari.Viewer()
                v.add_image(raw)
                v.add_labels(labels)
                napari.run()

                break  # Comment this line out in case you would like to visualize all samples.


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Download the datasets necessary for the workshop.")
    parser.add_argument(
        "-i", "--input_path", type=str, default="./data",
        help="The filepath to the folder where the image data will be downloaded. "
             "By default, the data will be stored in your current working directory at './data'."
    )
    parser.add_argument(
        "-d", "--dataset_name", type=str, default=None,
        help="The choice of dataset you would like to download. By default, it downloads all the datasets. "
             "Optionally, you can choose to download one of 'cellpose', 'hpa', 'embedseg' or 'platynereis'."
    )
    parser.add_argument(
        "-v", "--view", action="store_true", help="Whether to view the downloaded data."
    )
    args = parser.parse_args()

    _get_dataset_paths(path=args.input_path, dataset_name=args.dataset_name, view=args.view)


if __name__ == "__main__":
    main()
```
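For example, the following call (using the flags defined in `main` above) downloads only the EmbedSeg volume into `./data` and opens it in napari for inspection:

```bash
$ python download_datasets.py -i ./data -d embedseg -v
```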
