
Commit a9a8003

Merge pull request #744 from computational-cell-analytics/dev
Changes for release 1.1
2 parents 35cb739 + 6330b72 commit a9a8003

File tree: 92 files changed, +6153 / −785 lines

Large commits have some of their content hidden by default, so only a subset of the 92 changed files is shown below.

.github/workflows/test.yaml

Lines changed: 5 additions & 1 deletion

@@ -20,7 +20,9 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ["3.11", "3.12"]
+        # 3.12 currently not supported due to issues with nifty.
+        # python-version: ["3.11", "3.12"]
+        python-version: ["3.11"]

     steps:
       - name: Checkout
@@ -30,6 +32,8 @@ jobs:
         uses: mamba-org/setup-micromamba@v1
         with:
           environment-file: environment_cpu.yaml
+          create-args: >-
+            python=${{ matrix.python-version }}

       # Setup Qt libraries for GUI testing on Linux
       - uses: tlambert03/setup-qt-libs@v1

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -177,3 +177,6 @@ cython_debug/
 # Torch-em stuff
 checkpoints/
 logs/
+
+# "gpu_jobs" folder where slurm batch submission scripts are saved
+gpu_jobs/

development/check_3d_model.py

Lines changed: 81 additions & 0 deletions (new file)

import numpy as np
import torch
import micro_sam.util as util

from micro_sam.sam_3d_wrapper import get_3d_sam_model
from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D


def predict_3d_model():
    d_size = 8
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam_3d = get_3d_sam_model(device, d_size)

    input_ = 255 * np.random.rand(1, d_size, 3, 1024, 1024).astype("float32")
    with torch.no_grad():
        input_ = torch.from_numpy(input_).to(device)
        out = sam_3d(input_, multimask_output=False, image_size=1024)
        print(out["masks"].shape)


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, patch_shape, n_classes):
        self.patch_shape = patch_shape
        self.n_classes = n_classes

    def __len__(self):
        return 5

    def __getitem__(self, index):
        image_shape = (self.patch_shape[0], 3) + self.patch_shape[1:]
        x = np.random.rand(*image_shape).astype("float32")
        label_shape = (self.n_classes,) + self.patch_shape
        y = (np.random.rand(*label_shape) > 0.5).astype("float32")
        return x, y


def get_loader(patch_shape, n_classes, batch_size):
    ds = DummyDataset(patch_shape, n_classes)
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
    loader.shuffle = True
    return loader


# TODO: we are missing the resizing in the model, so currently this only supports 1024x1024
def train_3d_model():
    from micro_sam.training.util import ConvertToSemanticSamInputs

    d_size = 4
    n_classes = 5
    batch_size = 2
    image_size = 512

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam_3d = get_3d_sam_model(device, n_classes=n_classes, image_size=image_size)

    train_loader = get_loader((d_size, image_size, image_size), n_classes, batch_size)
    val_loader = get_loader((d_size, image_size, image_size), n_classes, batch_size)

    optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5)

    trainer = SemanticSamTrainer3D(
        name="test-sam",
        model=sam_3d,
        convert_inputs=ConvertToSemanticSamInputs(),
        num_classes=n_classes,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        compile_model=False,
    )
    trainer.fit(10)


def main():
    # predict_3d_model()
    train_3d_model()


if __name__ == "__main__":
    main()
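
The TODO in check_3d_model.py above notes that resizing is still missing from the 3D wrapper, so the model only accepts 1024 x 1024 inputs. Below is a hedged sketch (not part of the commit) of one way to work around this on the input side, assuming the (batch, depth, channels, height, width) layout used in predict_3d_model:

import torch
import torch.nn.functional as F


def resize_volume_to_1024(volume: torch.Tensor, target: int = 1024) -> torch.Tensor:
    """Resize a (B, D, C, H, W) volume so that its spatial dimensions become target x target."""
    b, d, c, h, w = volume.shape
    # Merge batch and depth so each slice is resized as an ordinary 2d image.
    flat = volume.reshape(b * d, c, h, w)
    flat = F.interpolate(flat, size=(target, target), mode="bilinear", align_corners=False)
    return flat.reshape(b, d, c, target, target)

With such a helper, a 512 x 512 dummy volume could be upsampled before calling sam_3d(..., image_size=1024); the predicted masks would then need to be downsampled back to the original resolution.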

Lines changed: 154 additions & 0 deletions (new file)

import napari
from elf.io import open_file
import h5py
import os
import torch
import numpy as np

import micro_sam.sam_3d_wrapper as sam_3d
import micro_sam.util as util
# from micro_sam.segment_instances import (
#     segment_instances_from_embeddings,
#     segment_instances_sam,
#     segment_instances_from_embeddings_3d,
# )
from micro_sam import multi_dimensional_segmentation as mds
from micro_sam.visualization import compute_pca

INPUT_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5"
# EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/embedding-mito-3d.zarr"
EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/usr/nimlufre/"
INPUT_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5"
EMBEDDINGS_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/"
INPUT_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/4007_cutout_1.h5"
EMBEDDINGS_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/embedding-mito-3d.zarr"
TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01"
EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr"


# def cell_segmentation_3d() -> None:
#     with open_file(TIMESERIES_PATH, mode="r") as f:
#         timeseries = f["*.tif"][:50]
#
#     predictor = util.get_sam_model()
#     image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH)
#
#     seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)
#
#     v = napari.Viewer()
#     v.add_image(timeseries)
#     v.add_labels(seg)
#     napari.run()


# def _get_dataset_and_reshape(path: str, key: str = "raw", shape: tuple = (32, 256, 256)) -> np.ndarray:
#     with h5py.File(path, "r") as f:
#         # Check if the key exists in the file
#         if key not in f:
#             raise KeyError(f"Dataset with key '{key}' not found in file '{path}'.")
#
#         # Load the dataset
#         dataset = f[key][...]
#
#     # Reshape the dataset
#     if dataset.shape != shape:
#         try:
#             # Attempt to reshape the dataset to the desired shape
#             dataset = dataset.reshape(shape)
#         except ValueError:
#             raise ValueError(f"Failed to reshape dataset with key '{key}' to shape {shape}.")
#
#     return dataset


def get_dataset_cutout(path: str, key: str = "raw", shape: tuple = (32, 256, 256),
                       start_index: tuple = (0, 0, 0)) -> np.ndarray:
    """Loads a cutout from a dataset in an HDF5 file.

    Args:
        path (str): Path to the HDF5 file.
        key (str, optional): Key of the dataset to load. Defaults to "raw".
        shape (tuple, optional): Desired shape of the cutout. Defaults to (32, 256, 256).
        start_index (tuple, optional): Starting index for the cutout within the dataset.
            Defaults to (0, 0, 0). Pass None to select a random starting point within valid bounds.

    Returns:
        np.ndarray: The loaded cutout of the dataset with the specified shape.

    Raises:
        KeyError: If the specified key is not found in the HDF5 file.
        ValueError: If the cutout shape exceeds the dataset dimensions or the starting index is invalid.
    """
    with h5py.File(path, "r") as f:
        dataset = f[key]
        dataset_shape = dataset.shape
        print("original data shape", dataset_shape)

        # Validate cutout shape
        if any(s > d for s, d in zip(shape, dataset_shape)):
            raise ValueError(f"Cutout shape {shape} exceeds dataset dimensions {dataset_shape}.")

        # Generate a random starting index if not provided
        if start_index is None:
            start_index = tuple(np.random.randint(0, dim - s + 1) for dim, s in zip(dataset_shape, shape))

        # Calculate the end index
        end_index = tuple(min(i + s, dim) for i, s, dim in zip(start_index, shape, dataset_shape))

        # Load the cutout
        cutout = dataset[start_index[0]:end_index[0],
                         start_index[1]:end_index[1],
                         start_index[2]:end_index[2]]
        print("cutout data shape", cutout.shape)

    return cutout


def mito_segmentation_3d() -> None:
    patch_shape = (32, 256, 256)
    start_index = (10, 32, 64)
    data_slice = get_dataset_cutout(INPUT_PATH_LOCAL, shape=patch_shape)  # start_index=start_index

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_type = "vit_b"
    predictor, sam = util.get_sam_model(return_sam=True, model_type=model_type, device=device)

    d_size = 3
    predictor3d = sam_3d.Predictor3D(sam, d_size)
    print(predictor3d)
    # breakpoint()
    predictor3d.model.forward(torch.from_numpy(data_slice), multimask_output=False, image_size=patch_shape)
    # output = predictor3d.model([data_slice], multimask_output=False)  # image_size=patch_shape

    # predictor3d._hash = util.models().registry[model_type]
    # predictor3d.model_name = model_type
    # image_embeddings = util.precompute_image_embeddings(predictor3d, volume, EMBEDDINGS_PATH_CLUSTER)
    # seg = util.segment_instances_from_embeddings_3d(predictor3d, image_embeddings)

    # prediction_filename = os.path.join(EMBEDDINGS_PATH_CLUSTER, f"prediction_{INPUT_PATH_CLUSTER}.h5")
    # with h5py.File(prediction_filename, "w") as prediction_file:
    #     prediction_file.create_dataset("prediction", data=seg)

    # visualize
    # v = napari.Viewer()
    # v.add_image(volume)
    # v.add_labels(seg)
    # v.add_labels(seg_sam)
    # napari.run()


def main():
    # automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py')
    # nucleus_segmentation(use_mws=True)
    mito_segmentation_3d()

    # automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py')
    # cell_segmentation(use_mws=True)
    # cell_segmentation_3d()


if __name__ == "__main__":
    main()
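
A hedged usage sketch for get_dataset_cutout from the script above (not part of the commit); the HDF5 path is a placeholder. With an explicit start_index the cutout is clipped to the dataset bounds, and with start_index=None a random valid origin is drawn:

# Fixed origin: take a (32, 256, 256) block starting at (10, 32, 64).
cutout = get_dataset_cutout("/path/to/volume.h5", key="raw", shape=(32, 256, 256), start_index=(10, 32, 64))

# Random origin: passing None samples a starting point within valid bounds.
random_cutout = get_dataset_cutout("/path/to/volume.h5", key="raw", shape=(32, 256, 256), start_index=None)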

environment_cpu.yaml

Lines changed: 2 additions & 1 deletion

@@ -4,10 +4,11 @@ channels:
   - conda-forge
 dependencies:
   - cpuonly
+  # This pin is necessary because later nifty versions have import errors on windows.
   - nifty =1.2.1=*_4
   - imagecodecs
   - magicgui
-  - napari <0.5
+  - napari
   - pip
   - pooch
   - pyqt

environment_gpu.yaml

Lines changed: 2 additions & 1 deletion

@@ -5,9 +5,10 @@ channels:
   - conda-forge
 dependencies:
   - imagecodecs
+  # This pin is necessary because later nifty versions have import errors on windows.
   - nifty =1.2.1=*_4
   - magicgui
-  - napari <0.5
+  - napari
   - pip
   - pooch
   - pyqt

examples/annotator_2d.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def wholeslide_annotator(use_finetuned_model):

 def main():
     # Whether to use the fine-tuned SAM model for light microscopy data.
-    use_finetuned_model = False
+    use_finetuned_model = True

     # 2d annotator for livecell data
     livecell_annotator(use_finetuned_model)
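
For context, a hedged sketch (not part of the commit) of how such a flag is typically used in the micro_sam examples to pick between the default and the light-microscopy finetuned model; the exact model names below are assumptions:

import micro_sam.util as util

use_finetuned_model = True
# "vit_b_lm" refers to a light-microscopy generalist model, "vit_b" to the default SAM backbone.
model_type = "vit_b_lm" if use_finetuned_model else "vit_b"
predictor = util.get_sam_model(model_type=model_type)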

finetuning/evaluation/evaluate_amg.py

Lines changed: 16 additions & 15 deletions

@@ -3,20 +3,23 @@
 from micro_sam.evaluation.evaluation import run_evaluation
 from micro_sam.evaluation.inference import run_amg

-from util import get_paths  # comment this and create a custom function with the same name to run amg on your data
-from util import get_pred_paths, get_default_arguments, VANILLA_MODELS
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run amg on your data
+    get_pred_paths, get_default_arguments
+)


-def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_amg(
-        checkpoint,
-        model_type,
-        experiment_folder,
-        val_image_paths,
-        val_gt_paths,
-        test_image_paths
+        checkpoint=checkpoint,
+        model_type=model_type,
+        experiment_folder=experiment_folder,
+        val_image_paths=val_image_paths,
+        val_gt_paths=val_gt_paths,
+        test_image_paths=test_image_paths,
+        peft_kwargs=peft_kwargs,
     )
     return prediction_folder

@@ -32,12 +35,10 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder):

 def main():
     args = get_default_arguments()
-    if args.checkpoint is None:
-        ckpt = VANILLA_MODELS[args.model]
-    else:
-        ckpt = args.checkpoint
-
-    prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder)
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
+    prediction_folder = run_amg_inference(
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs
+    )
     eval_amg(args.dataset, prediction_folder, args.experiment_folder)

finetuning/evaluation/evaluate_instance_segmentation.py

Lines changed: 16 additions & 11 deletions

@@ -3,20 +3,25 @@
 from micro_sam.evaluation.evaluation import run_evaluation
 from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder

-from util import get_paths  # comment this and create a custom function with the same name to run ais on your data
-from util import get_pred_paths, get_default_arguments
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run ais on your data
+    get_pred_paths, get_default_arguments
+)


-def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_instance_segmentation_with_decoder_inference(
+    dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs,
+):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_instance_segmentation_with_decoder(
-        checkpoint,
-        model_type,
-        experiment_folder,
-        val_image_paths,
-        val_gt_paths,
-        test_image_paths
+        checkpoint=checkpoint,
+        model_type=model_type,
+        experiment_folder=experiment_folder,
+        val_image_paths=val_image_paths,
+        val_gt_paths=val_gt_paths,
+        test_image_paths=test_image_paths,
+        peft_kwargs=peft_kwargs,
     )
     return prediction_folder

@@ -32,9 +37,9 @@ def eval_instance_segmentation_with_decoder(dataset_name, prediction_folder, exp

 def main():
     args = get_default_arguments()
-
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
     prediction_folder = run_instance_segmentation_with_decoder_inference(
-        args.dataset, args.model, args.checkpoint, args.experiment_folder
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs,
     )
     eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder)

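Both evaluation scripts above now read args.peft_rank and args.peft_module from get_default_arguments and forward them as peft_kwargs. Below is a hedged sketch (not part of the commit) of what the shared parser in finetuning/evaluation/util.py might look like; the actual flag names and defaults may differ:

import argparse


def get_default_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True, help="Name of the dataset to evaluate on.")
    parser.add_argument("--model", required=True, help="SAM model type, e.g. vit_b.")
    parser.add_argument("--checkpoint", default=None, help="Optional path to a finetuned checkpoint.")
    parser.add_argument("--experiment_folder", required=True, help="Folder for predictions and results.")
    # New PEFT options that the updated scripts pass on via peft_kwargs.
    parser.add_argument("--peft_rank", type=int, default=None, help="Rank for parameter-efficient finetuning.")
    parser.add_argument("--peft_module", default=None, help="PEFT module to use, e.g. LoRA.")
    return parser.parse_args()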