Skip to content

Commit 973c778

Browse files
Merge pull request #1091 from computational-cell-analytics/dev
Merge dev to master
2 parents 7a3e3b5 + b1562c9 commit 973c778

File tree

8 files changed

+328
-18
lines changed

8 files changed

+328
-18
lines changed

development/check_data_count.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import os
2+
from glob import glob
3+
4+
import numpy as np
5+
import imageio.v3 as imageio
6+
7+
from torch_em.data import datasets
8+
9+
from elf.io import open_file
10+
11+
12+
ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/micro_sam/data"
13+
14+
15+
def check_data_count(lm_version="v3"):
16+
image_counter, object_counter = 0, 0
17+
18+
# LIVECell data.
19+
image_paths, label_paths = datasets.light_microscopy.livecell.get_livecell_paths(
20+
path=os.path.join(ROOT, "livecell"), split="train",
21+
)
22+
image_counter += len(image_paths)
23+
object_counter += sum(
24+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
25+
)
26+
27+
print("LIVECell", image_counter, object_counter)
28+
29+
# DeepBacs data.
30+
image_dir, label_dir = datasets.light_microscopy.deepbacs.get_deepbacs_paths(
31+
path=os.path.join(ROOT, "deepbacs"), bac_type="mixed", split="train",
32+
)
33+
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
34+
label_paths = sorted(glob(os.path.join(label_dir, "*.tif")))
35+
36+
curr_image_counter = len(image_paths)
37+
curr_object_counter = sum(
38+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
39+
)
40+
41+
image_counter += curr_image_counter
42+
object_counter += curr_object_counter
43+
44+
print("DeepBacs", curr_image_counter, curr_object_counter)
45+
46+
# TissueNet data.
47+
sample_paths = datasets.light_microscopy.tissuenet.get_tissuenet_paths(
48+
path=os.path.join(ROOT, "tissuenet"), split="train",
49+
)
50+
curr_image_counter = len(sample_paths)
51+
curr_object_counter = sum(
52+
[len(np.unique(open_file(p)["labels/cell"])[1:]) for p in sample_paths]
53+
)
54+
55+
image_counter += curr_image_counter
56+
object_counter += curr_object_counter
57+
58+
print("TissueNet", curr_image_counter, curr_object_counter)
59+
60+
# PlantSeg (Root) data.
61+
volume_paths = datasets.light_microscopy.plantseg.get_plantseg_paths(
62+
path=os.path.join(ROOT, "plantseg"), name="root", split="train",
63+
)
64+
curr_image_counter, curr_object_counter = 0, 0
65+
for p in volume_paths:
66+
f = open_file(p)
67+
curr_image_counter += f["raw"].shape[0]
68+
curr_object_counter += sum(
69+
[len(np.unique(curr_label)[1:]) for curr_label in f["label"]]
70+
)
71+
72+
image_counter += curr_image_counter
73+
object_counter += curr_object_counter
74+
75+
print("PlantSeg (Root)", curr_image_counter, curr_object_counter)
76+
77+
# NeurIPS CellSeg data.
78+
image_paths, label_paths = datasets.light_microscopy.neurips_cell_seg.get_neurips_cellseg_paths(
79+
root=os.path.join(ROOT, "neurips_cellseg"), split="train",
80+
)
81+
curr_image_counter = len(image_paths)
82+
curr_object_counter = sum(
83+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
84+
)
85+
86+
image_counter += curr_image_counter
87+
object_counter += curr_object_counter
88+
89+
print("NeurIPS CellSeg", curr_image_counter, curr_object_counter)
90+
91+
# CTC data.
92+
curr_image_counter, curr_object_counter = 0, 0
93+
for dataset_name in datasets.ctc.CTC_CHECKSUMS["train"].keys():
94+
if dataset_name in ["Fluo-N2DH-GOWT1", "Fluo-N2DL-HeLa"]:
95+
continue
96+
97+
image_dirs, label_dirs = datasets.light_microscopy.ctc.get_ctc_segmentation_paths(
98+
path=os.path.join(ROOT, "ctc"), dataset_name=dataset_name,
99+
)
100+
image_paths = [p for d in image_dirs for p in sorted(glob(os.path.join(d, "*.tif")))]
101+
label_paths = [p for d in label_dirs for p in sorted(glob(os.path.join(d, "*.tif")))]
102+
103+
curr_image_counter += len(image_paths)
104+
curr_object_counter += sum(
105+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
106+
)
107+
108+
image_counter += curr_image_counter
109+
object_counter += curr_object_counter
110+
111+
print("CTC", curr_image_counter, curr_object_counter)
112+
113+
# DSB Nucleus data.
114+
image_paths, label_paths = datasets.light_microscopy.dsb.get_dsb_paths(
115+
path=os.path.join(ROOT, "dsb"), source="reduced", split="train",
116+
)
117+
curr_image_counter = len(image_paths)
118+
curr_object_counter = sum(
119+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
120+
)
121+
122+
image_counter += curr_image_counter
123+
object_counter += curr_object_counter
124+
125+
print("DSB Nucleus", curr_image_counter, curr_object_counter)
126+
127+
if lm_version == "v2":
128+
return image_counter, object_counter
129+
130+
# EmbedSeg data.
131+
curr_image_counter, curr_object_counter = 0, 0
132+
names = [
133+
"Mouse-Organoid-Cells-CBG", "Mouse-Skull-Nuclei-CBG", "Platynereis-ISH-Nuclei-CBG", "Platynereis-Nuclei-CBG",
134+
]
135+
for name in names:
136+
image_paths, label_paths = datasets.light_microscopy.embedseg_data.get_embedseg_paths(
137+
path=os.path.join(ROOT, "embedseg"), name=name, split="train",
138+
)
139+
curr_image_counter += sum(
140+
[imageio.imread(p).shape[0] for p in image_paths]
141+
)
142+
curr_object_counter += sum(
143+
[sum(len(np.unique(curr_label)[1:]) for curr_label in imageio.imread(p)) for p in label_paths]
144+
)
145+
146+
image_counter += curr_image_counter
147+
object_counter += curr_object_counter
148+
149+
print("EmbedSeg", curr_image_counter, curr_object_counter)
150+
151+
# CVZ Fluo data.
152+
curr_image_counter, curr_object_counter = 0, 0
153+
for stain_choice in ["cell", "dapi"]:
154+
image_paths, label_paths = datasets.light_microscopy.cvz_fluo.get_cvz_fluo_paths(
155+
path=os.path.join(ROOT, "cvz"), stain_choice=stain_choice,
156+
)
157+
curr_image_counter += len(image_paths)
158+
curr_object_counter += sum(
159+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
160+
)
161+
162+
image_counter += curr_image_counter
163+
object_counter += curr_object_counter
164+
165+
print("CVZ Fluo", curr_image_counter, curr_object_counter)
166+
167+
# DynamicNuclearNet data.
168+
sample_paths = datasets.light_microscopy.dynamicnuclearnet.get_dynamicnuclearnet_paths(
169+
path=os.path.join(ROOT, "dynamicnuclearnet"), split="train",
170+
)
171+
172+
curr_image_counter = len(sample_paths)
173+
curr_object_counter = sum(
174+
[len(np.unique(open_file(p)["labels"])[1:]) for p in sample_paths]
175+
)
176+
177+
image_counter += curr_image_counter
178+
object_counter += curr_object_counter
179+
180+
print("DynamicNuclearNet", curr_image_counter, curr_object_counter)
181+
182+
# CellPose data.
183+
image_paths, label_paths = datasets.light_microscopy.cellpose.get_cellpose_paths(
184+
path=os.path.join(ROOT, "cellpose"), split="train", choice="cyto",
185+
)
186+
curr_image_counter = len(image_paths)
187+
curr_object_counter = sum(
188+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
189+
)
190+
191+
image_counter += curr_image_counter
192+
object_counter += curr_object_counter
193+
194+
print("CellPose", curr_image_counter, curr_object_counter)
195+
196+
# OmniPose data.
197+
image_paths, label_paths = datasets.light_microscopy.omnipose.get_omnipose_paths(
198+
path=os.path.join(ROOT, "omnipose"), split="train",
199+
)
200+
curr_image_counter = len(image_paths)
201+
curr_object_counter = sum(
202+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
203+
)
204+
205+
image_counter += curr_image_counter
206+
object_counter += curr_object_counter
207+
208+
print("OmniPose", curr_image_counter, curr_object_counter)
209+
210+
# OrgaSegment data.
211+
image_paths, label_paths = datasets.light_microscopy.orgasegment.get_orgasegment_paths(
212+
path=os.path.join(ROOT, "orgasegment"), split="train",
213+
)
214+
curr_image_counter = len(image_paths)
215+
curr_object_counter = sum(
216+
[len(np.unique(imageio.imread(p))[1:]) for p in label_paths]
217+
)
218+
219+
image_counter += curr_image_counter
220+
object_counter += curr_object_counter
221+
222+
print("OrgaSegment", curr_image_counter, curr_object_counter)
223+
224+
return image_counter, object_counter
225+
226+
227+
def main():
228+
# image_counts, object_counts = check_data_count("v2")
229+
# print(f"v2 Model - Count of images: '{image_counts}'; and count of objects: '{object_counts}'")
230+
231+
image_counts, object_counts = check_data_count("v3")
232+
print(f"v3 and v4 Model - Count of images: '{image_counts}'; and count of objects: '{object_counts}'")
233+
234+
235+
if __name__ == "__main__":
236+
main()
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
from glob import glob
3+
from natsort import natsorted
4+
5+
import torch
6+
7+
import micro_sam.training as sam_training
8+
from micro_sam.util import export_custom_sam_model
9+
10+
11+
def train_embl_alm_data(checkpoint_name):
12+
"""Training a MicroSAM model for https://github.com/computational-cell-analytics/micro-sam/issues/1084.
13+
"""
14+
# All hyperparameters for training.
15+
batch_size = 1
16+
patch_shape = (512, 512)
17+
n_objects_per_batch = 25
18+
device = torch.device("cuda")
19+
20+
# Get the filepaths to images and corresponding labels.
21+
image_paths = natsorted(glob(os.path.join(os.getcwd(), "data_same_size", "*.tif")))
22+
label_paths = natsorted(glob(os.path.join(os.getcwd(), "masks_same_size", "*.tif")))
23+
24+
# Next, prepare the dataloaders.
25+
kwargs = {
26+
"batch_size": batch_size,
27+
"patch_shape": patch_shape,
28+
"with_segmentation_decoder": True,
29+
"num_workers": 16,
30+
"shuffle": True,
31+
}
32+
33+
train_loader = sam_training.default_sam_loader(
34+
raw_paths=image_paths[:-5], raw_key=None, label_paths=label_paths[:-5], label_key=None, **kwargs,
35+
)
36+
val_loader = sam_training.default_sam_loader(
37+
raw_paths=image_paths[-5:], raw_key=None, label_paths=label_paths[-5:], label_key=None, **kwargs,
38+
)
39+
40+
# Run training.
41+
sam_training.train_sam(
42+
name=checkpoint_name,
43+
model_type="vit_b_lm",
44+
train_loader=train_loader,
45+
val_loader=val_loader,
46+
n_epochs=10,
47+
n_objects_per_batch=n_objects_per_batch,
48+
with_segmentation_decoder=True,
49+
device=device,
50+
)
51+
52+
53+
def main():
54+
checkpoint_name = "sam_embl_alm_fluo" # Name of the checkpoint, stored at "./checkpoints/<CHECKPOINT_NAME>"
55+
56+
train_embl_alm_data(checkpoint_name)
57+
58+
# Export the trained model.
59+
export_custom_sam_model(
60+
checkpoint_path=os.path.join("checkpoints", checkpoint_name, "best.pt"),
61+
model_type="vit_b",
62+
save_path="./finetuned_embl_alm_fluo_model.pth",
63+
)
64+
65+
66+
if __name__ == "__main__":
67+
main()

doc/band.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ BAND is a service offered by EMBL Heidelberg under the "The German Network for B
44
In order to use BAND and start `micro_sam` on it follow these steps:
55

66
## Start BAND
7-
- Go to https://bandv1.denbi.uni-tuebingen.de/ and click **Login**. If you have not used BAND before you will need to register for BAND. Currently you can only sign up via a Google account. NOTE: It takes a couple of seconds for the "Launch Desktops" window to appear.
7+
- Go to https://bandv1.denbi.uni-tuebingen.de/ (another site available at https://band.vm.fedcloud.eu/, choose either) and click **Login**. If you have not used BAND before you will need to register for BAND. Currently you can only sign up via a Google account. NOTE: It takes a couple of seconds for the "Launch Desktops" window to appear.
88
- Launch a BAND desktop with sufficient resources. It's particularly important to select a GPU. The settings from the image below are a good choice.
99
- Go to the desktop by clicking **GO TO DESKTOP** in the **Running Desktops** menu. See also the screenshot below.
1010

environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies:
2626
- torch_em >=0.7.10
2727
- tqdm
2828
- timm
29+
- trackastra
2930
- xarray
3031
- zarr
3132
- pip:

micro_sam/bioimageio/model_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def export_sam_model(
476476
source=Path(checkpoint_path),
477477
architecture=architecture,
478478
pytorch_version=spec.Version(torch.__version__),
479-
dependencies=spec.EnvironmentFileDescr(source=dependency_file),
479+
dependencies=spec.FileDescr(source=dependency_file),
480480
)
481481
)
482482

micro_sam/evaluation/instance_segmentation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,20 +229,20 @@ def run_instance_segmentation_grid_search(
229229
image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
230230
gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])
231231

232+
if tiling_window_params is None:
233+
tiling_window_params = {}
234+
232235
if embedding_dir is None:
233236
embedding_path = None
237+
segmenter.initialize(image, **tiling_window_params)
238+
234239
else:
235240
assert predictor is not None
236241
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
237-
238-
if tiling_window_params is None:
239-
tiling_window_params = {}
240-
241-
image_embeddings = util.precompute_image_embeddings(
242-
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
243-
)
244-
245-
segmenter.initialize(image, image_embeddings, **tiling_window_params)
242+
image_embeddings = util.precompute_image_embeddings(
243+
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
244+
)
245+
segmenter.initialize(image, image_embeddings, **tiling_window_params)
246246

247247
_grid_search_iteration(
248248
segmenter, gs_combinations, gt, image_name,

micro_sam/instance_segmentation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,6 @@ def get_unetr(
805805
use_skip_connection=False,
806806
resize_input=True,
807807
use_conv_transpose=use_conv_transpose,
808-
809808
)
810809

811810
if decoder_state is not None:

0 commit comments

Comments
 (0)