Commit 16bd67a

Update generalist experiments
1 parent 34814bc commit 16bd67a


6 files changed, +201 -3 lines changed

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import argparse
import os
from glob import glob
from subprocess import run

import imageio.v3 as imageio

from tqdm import tqdm

DATA_ROOT = "/scratch/projects/nim00007/sam/datasets"
EXP_ROOT = "/scratch/projects/nim00007/sam/experiments/cellpose"

DATASETS = (
    "covid-if",
    "deepbacs",
    "hpa",
    "livecell",
    "lizard",
    "mouse-embryo",
    "plantseg-ovules",
    "plantseg-root",
    "tissuenet",
)


def load_cellpose_model():
    from cellpose import models

    device, gpu = models.assign_device(True, True)
    model = models.Cellpose(gpu=gpu, model_type="cyto", device=device)
    return model


def run_cellpose_segmentation(datasets, job_id):
    dataset = datasets[job_id]
    experiment_folder = os.path.join(EXP_ROOT, dataset)

    prediction_folder = os.path.join(experiment_folder, "predictions")
    os.makedirs(prediction_folder, exist_ok=True)

    image_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "image*.tif")))
    model = load_cellpose_model()

    for path in tqdm(image_paths, desc=f"Segmenting {dataset} with cellpose"):
        fname = os.path.basename(path)
        out_path = os.path.join(prediction_folder, fname)
        if os.path.exists(out_path):
            continue
        image = imageio.imread(path)
        if image.ndim == 3:
            assert image.shape[-1] == 3
            image = image.mean(axis=-1)
        assert image.ndim == 2
        seg = model.eval(image, diameter=None, flow_threshold=None, channels=[0, 0])[0]
        assert seg.shape == image.shape
        imageio.imwrite(out_path, seg, compression=5)


def submit_array_job(datasets):
    n_datasets = len(datasets)
    cmd = ["sbatch", "-a", f"0-{n_datasets-1}", "cellpose_baseline.sbatch"]
    run(cmd)


def evaluate_dataset(dataset):
    from micro_sam.evaluation.evaluation import run_evaluation

    gt_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "label*.tif")))
    experiment_folder = os.path.join(EXP_ROOT, dataset)
    pred_paths = sorted(glob(os.path.join(experiment_folder, "predictions", "*.tif")))
    assert len(gt_paths) == len(pred_paths), f"{len(gt_paths)}, {len(pred_paths)}"
    result_path = os.path.join(experiment_folder, "cellpose.csv")
    run_evaluation(gt_paths, pred_paths, result_path)


def evaluate_segmentations(datasets):
    for dataset in datasets:
        # we skip livecell, which has already been processed by cellpose
        if dataset == "livecell":
            continue
        evaluate_dataset(dataset)


def check_results(datasets):
    for ds in datasets:
        # we skip livecell, which has already been processed by cellpose
        if ds == "livecell":
            continue
        result_path = os.path.join(EXP_ROOT, ds, "cellpose.csv")
        if not os.path.exists(result_path):
            print("Cellpose results missing for", ds)
    print("All checks passed")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--segment", "-s", action="store_true")
    parser.add_argument("--evaluate", "-e", action="store_true")
    parser.add_argument("--check", "-c", action="store_true")
    parser.add_argument("--datasets", nargs="+")
    args = parser.parse_args()

    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)

    if args.datasets is None:
        datasets = DATASETS
    else:
        datasets = args.datasets
        assert all(ds in DATASETS for ds in datasets)

    if job_id is not None:
        run_cellpose_segmentation(datasets, int(job_id))
    elif args.segment:
        submit_array_job(datasets)
    elif args.evaluate:
        evaluate_segmentations(datasets)
    elif args.check:
        check_results(datasets)
    else:
        raise ValueError("Doing nothing")


if __name__ == "__main__":
    main()
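
The script doubles as SLURM array worker and submission/evaluation front end: when SLURM_ARRAY_TASK_ID is set it segments the dataset at that index, otherwise the --segment/--evaluate/--check flags choose the action. load_cellpose_model loads the pretrained "cyto" model, channels=[0, 0] runs Cellpose on a single grayscale channel, and the trailing [0] keeps only the mask array from the tuple that eval returns. For running without the cluster, a sequential driver along these lines would work (a hypothetical helper, not part of the commit; it assumes it lives in the same module so DATASETS and run_cellpose_segmentation are in scope):

def run_all_locally():
    # Hypothetical sequential alternative to the SLURM array: process each
    # dataset in turn with the same per-dataset segmentation routine.
    for job_id in range(len(DATASETS)):
        run_cellpose_segmentation(DATASETS, job_id)
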
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
#! /bin/bash
#SBATCH -c 4
#SBATCH --mem 48G
#SBATCH -t 300
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007

source activate cellpose
python cellpose_baseline.py $@
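
The wrapper forwards all command-line arguments to the Python script via $@, so a worker launched by submit_array_job sees the same --datasets selection as the submitting call and indexes into the same list. The command issued when all nine datasets are selected, sbatch -a 0-8 cellpose_baseline.sbatch, starts one task per index with SLURM_ARRAY_TASK_ID set; each task then effectively does the following (illustration only, mirroring the dispatch in main above):

import os

# Each array task sees its own SLURM_ARRAY_TASK_ID; with "sbatch -a 0-8",
# task 3 would convert "3" -> 3 and process DATASETS[3], i.e. "livecell".
job_id = int(os.environ["SLURM_ARRAY_TASK_ID"])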

finetuning/generalists/compile_results.py

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,6 @@ def compile_em():
     )
 
 
-# TODO
 def compile_lm():
     compile_results(
         ["vit_h", "vit_h_lm", "vit_b", "vit_b_lm"],
@@ -64,7 +63,8 @@ def compile_lm():
 
 
 def main():
-    compile_em()
+    # compile_em()
+    compile_lm()
 
 
 if __name__ == "__main__":

finetuning/generalists/evaluate_generalist.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,11 @@
     # Generalist EM models
     "vit_b_em": "/scratch/projects/nim00007/sam/models/EM/generalist/v2/vit_b/best.pt",
     "vit_h_em": "/scratch/projects/nim00007/sam/models/EM/generalist/v2/vit_h/best.pt",
+    # Specialist Models (we don't add livecell, because these results are all computed already)
+    "vit_b_tissuenet": "/scratch/projects/nim00007/sam/models/LM/TissueNet/vit_b/best.pt",
+    "vit_h_tissuenet": "/scratch/projects/nim00007/sam/models/LM/TissueNet/vit_h/best.pt",
+    "vit_b_deepbacs": "/scratch/projects/nim00007/sam/models/LM/DeepBacs/vit_b/best.pt",
+    "vit_h_deepbacs": "/scratch/projects/nim00007/sam/models/LM/DeepBacs/vit_h/best.pt",
 }
 
 
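The new entries extend the checkpoint registry with the TissueNet and DeepBacs specialist models; LiveCELL specialists are left out because those results already exist. A small, hypothetical sanity check (not part of the commit) shows how a registry like this can be validated before launching evaluations:

import os

def check_checkpoints(model_registry):
    # Hypothetical helper: warn about registry entries whose checkpoint
    # file is not present on the file system.
    missing = [name for name, path in model_registry.items() if not os.path.exists(path)]
    if missing:
        print("Missing checkpoints:", ", ".join(missing))
    else:
        print("All checkpoints found")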

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import os

import imageio.v3 as imageio
import micro_sam.evaluation.model_comparison as comparison
import torch_em

from util import get_data_paths, EM_DATASETS

OUTPUT_ROOT = "/scratch-grete/projects/nim00007/sam/experiments/model_comparison"


def _get_patch_shape(path):
    im_shape = imageio.imread(path).shape[:2]
    patch_shape = tuple(min(sh, 512) for sh in im_shape)
    return patch_shape


def get_loader(dataset):
    image_paths, gt_paths = get_data_paths(dataset, split="test")
    image_paths, gt_paths = image_paths[:100], gt_paths[:100]

    label_transform = torch_em.transform.label.connected_components
    loader = torch_em.default_segmentation_loader(
        image_paths, None, gt_paths, None,
        batch_size=1, patch_shape=_get_patch_shape(image_paths[0]),
        shuffle=True, n_samples=25, label_transform=label_transform,
    )
    return loader


def generate_comparison_for_dataset(dataset, model1, model2):
    output_folder = os.path.join(OUTPUT_ROOT, dataset)
    if os.path.exists(output_folder):
        return
    print("Generate model comparison data for", dataset)
    loader = get_loader(dataset)
    comparison.generate_data_for_model_comparison(loader, output_folder, model1, model2, n_samples=25)


# TODO
def create_comparison_images():
    pass


def generate_comparison_em():
    model1 = "vit_h"
    model2 = "vit_h_em"
    for dataset in EM_DATASETS:
        generate_comparison_for_dataset(dataset, model1, model2)
    create_comparison_images()


def main():
    # generate_comparison_lm()
    generate_comparison_em()


if __name__ == "__main__":
    main()
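
The loader caps the patch shape at 512 pixels per spatial dimension, so torch_em never requests patches larger than the first test image. A standalone illustration of that clamping logic (the names here are illustrative, not from the commit):

def clamp_patch_shape(im_shape, max_size=512):
    # Cap each spatial dimension at max_size, as _get_patch_shape does above.
    return tuple(min(sh, max_size) for sh in im_shape)

print(clamp_patch_shape((2048, 2048)))  # (512, 512)
print(clamp_patch_shape((348, 520)))    # (348, 512)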

micro_sam/evaluation/model_comparison.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1,
 
 
 def generate_data_for_model_comparison(
-    loader: torch.utils.DataLoader,
+    loader: torch.utils.data.DataLoader,
     output_folder: Union[str, os.PathLike],
     model_type1: str,
     model_type2: str,
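
This one-line fix corrects the type annotation: PyTorch exposes the loader class as torch.utils.data.DataLoader, not torch.utils.DataLoader, so the old annotation pointed at a non-existent attribute. An equivalent style is to import the class directly (a sketch, not part of the commit):

from torch.utils.data import DataLoader

def takes_loader(loader: DataLoader) -> None:
    # Illustrative signature only; the real function takes more parameters.
    ...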
