Commit 4150fd9

Merge pull request #144 from computational-cell-analytics/generalist-experiments
Implement generalist model evaluation experiments
2 parents 0e19472 + f81b400 commit 4150fd9

23 files changed: +759 −399 lines
cellpose_baseline.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import argparse
import os
from glob import glob
from subprocess import run

import imageio.v3 as imageio

from tqdm import tqdm

DATA_ROOT = "/scratch/projects/nim00007/sam/datasets"
EXP_ROOT = "/scratch/projects/nim00007/sam/experiments/cellpose"

DATASETS = (
    "covid-if",
    "deepbacs",
    "hpa",
    "livecell",
    "lizard",
    "mouse-embryo",
    "plantseg-ovules",
    "plantseg-root",
    "tissuenet",
)


def load_cellpose_model():
    from cellpose import models

    device, gpu = models.assign_device(True, True)
    model = models.Cellpose(gpu=gpu, model_type="cyto", device=device)
    return model


def run_cellpose_segmentation(datasets, job_id):
    # Each SLURM array task segments exactly one dataset.
    dataset = datasets[job_id]
    experiment_folder = os.path.join(EXP_ROOT, dataset)

    prediction_folder = os.path.join(experiment_folder, "predictions")
    os.makedirs(prediction_folder, exist_ok=True)

    image_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "image*.tif")))
    model = load_cellpose_model()

    for path in tqdm(image_paths, desc=f"Segmenting {dataset} with cellpose"):
        fname = os.path.basename(path)
        out_path = os.path.join(prediction_folder, fname)
        if os.path.exists(out_path):
            continue
        image = imageio.imread(path)
        if image.ndim == 3:  # convert RGB images to grayscale
            assert image.shape[-1] == 3
            image = image.mean(axis=-1)
        assert image.ndim == 2
        seg = model.eval(image, diameter=None, flow_threshold=None, channels=[0, 0])[0]
        assert seg.shape == image.shape
        imageio.imwrite(out_path, seg, compression=5)


def submit_array_job(datasets):
    n_datasets = len(datasets)
    cmd = ["sbatch", "-a", f"0-{n_datasets-1}", "cellpose_baseline.sbatch"]
    run(cmd)


def evaluate_dataset(dataset):
    from micro_sam.evaluation.evaluation import run_evaluation

    gt_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "label*.tif")))
    experiment_folder = os.path.join(EXP_ROOT, dataset)
    pred_paths = sorted(glob(os.path.join(experiment_folder, "predictions", "*.tif")))
    assert len(gt_paths) == len(pred_paths), f"{len(gt_paths)}, {len(pred_paths)}"
    result_path = os.path.join(experiment_folder, "cellpose.csv")
    run_evaluation(gt_paths, pred_paths, result_path)


def evaluate_segmentations(datasets):
    for dataset in datasets:
        # We skip livecell, which has already been processed with cellpose.
        if dataset == "livecell":
            continue
        evaluate_dataset(dataset)


def check_results(datasets):
    missing = False
    for ds in datasets:
        # We skip livecell, which has already been processed with cellpose.
        if ds == "livecell":
            continue
        result_path = os.path.join(EXP_ROOT, ds, "cellpose.csv")
        if not os.path.exists(result_path):
            print("Cellpose results missing for", ds)
            missing = True
    if not missing:
        print("All checks passed")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--segment", "-s", action="store_true")
    parser.add_argument("--evaluate", "-e", action="store_true")
    parser.add_argument("--check", "-c", action="store_true")
    parser.add_argument("--datasets", nargs="+")
    args = parser.parse_args()

    # Set by SLURM for array tasks; used to dispatch the segmentation jobs.
    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)

    if args.datasets is None:
        datasets = DATASETS
    else:
        datasets = args.datasets
        assert all(ds in DATASETS for ds in datasets)

    if job_id is not None:
        run_cellpose_segmentation(datasets, int(job_id))
    elif args.segment:
        submit_array_job(datasets)
    elif args.evaluate:
        evaluate_segmentations(datasets)
    elif args.check:
        check_results(datasets)
    else:
        raise ValueError("Doing nothing")


if __name__ == "__main__":
    main()
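
For reference, each array task re-enters main() with SLURM_ARRAY_TASK_ID set by SLURM, so the per-task dispatch reduces to an index lookup into DATASETS. A minimal sketch of that mapping (the task id below is illustrative, not from a real run):

import os

DATASETS = ("covid-if", "deepbacs", "hpa", "livecell", "lizard",
            "mouse-embryo", "plantseg-ovules", "plantseg-root", "tissuenet")

# `sbatch -a 0-8` starts nine tasks; SLURM sets the id for each of them.
os.environ.setdefault("SLURM_ARRAY_TASK_ID", "1")

job_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
print(DATASETS[job_id])  # -> "deepbacs" for task 1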
cellpose_baseline.sbatch

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
#! /bin/bash
#SBATCH -c 4
#SBATCH --mem 48G
#SBATCH -t 300
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007

source activate cellpose
python cellpose_baseline.py "$@"
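
Submitting via `python cellpose_baseline.py --segment` boils down to the same subprocess call that submit_array_job issues; a sketch of the equivalent submission for all nine datasets (assuming both files sit in the working directory):

from subprocess import run

# One array task per dataset, indices 0-8 for the nine entries in DATASETS.
run(["sbatch", "-a", "0-8", "cellpose_baseline.sbatch"])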
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
import os
from glob import glob

import pandas as pd

from evaluate_generalist import EXPERIMENT_ROOT
from util import EM_DATASETS, LM_DATASETS


def get_results(model, ds):
    res_folder = os.path.join(EXPERIMENT_ROOT, model, ds, "results")
    # Collect the per-prompt results (box and point prompts) ...
    res_paths = sorted(glob(os.path.join(res_folder, "box", "*.csv"))) +\
        sorted(glob(os.path.join(res_folder, "points", "*.csv")))

    # ... and the automatic mask generation results, if present.
    amg_res = os.path.join(res_folder, "amg.csv")
    if os.path.exists(amg_res):
        res_paths.append(amg_res)

    results = []
    for path in res_paths:
        prompt_res = pd.read_csv(path)
        prompt_name = os.path.splitext(os.path.relpath(path, res_folder))[0]
        prompt_res.insert(0, "prompt", [prompt_name])
        results.append(prompt_res)
    results = pd.concat(results)
    results.insert(0, "dataset", results.shape[0] * [ds])

    return results


def compile_results(models, datasets, out_path, load_results=False):
    results = []

    for model in models:
        model_results = []

        for ds in datasets:
            ds_results = get_results(model, ds)
            model_results.append(ds_results)

        model_results = pd.concat(model_results)
        model_results.insert(0, "model", [model] * model_results.shape[0])
        results.append(model_results)

    results = pd.concat(results)
    # If load_results is set, append to the previously compiled results.
    if load_results:
        assert os.path.exists(out_path)
        all_results = pd.read_csv(out_path)
        results = pd.concat([all_results, results])

    results.to_csv(out_path, index=False)


def compile_em():
    compile_results(
        ["vit_h", "vit_h_em", "vit_b", "vit_b_em"],
        EM_DATASETS,
        os.path.join(EXPERIMENT_ROOT, "evaluation-em.csv")
    )


def add_cellpose_results(datasets, out_path):
    cp_root = "/scratch/projects/nim00007/sam/experiments/cellpose"

    results = []
    for dataset in datasets:
        # No cellpose baseline was run for livecell (see cellpose_baseline.py).
        if dataset == "livecell":
            continue
        res_path = os.path.join(cp_root, dataset, "cellpose.csv")
        ds_res = pd.read_csv(res_path)
        ds_res.insert(0, "prompt", ["cellpose"] * ds_res.shape[0])
        ds_res.insert(0, "dataset", [dataset] * ds_res.shape[0])
        results.append(ds_res)

    results = pd.concat(results)
    results.insert(0, "model", ["cellpose"] * results.shape[0])

    all_results = pd.read_csv(out_path)
    results = pd.concat([all_results, results])
    results.to_csv(out_path, index=False)


def compile_lm():
    res_path = os.path.join(EXPERIMENT_ROOT, "evaluation-lm.csv")
    compile_results(
        ["vit_h", "vit_h_lm", "vit_b", "vit_b_lm"], LM_DATASETS, res_path
    )

    # Add the deepbacs and tissuenet specialist results.
    assert os.path.exists(res_path)
    compile_results(["vit_h_tissuenet", "vit_b_tissuenet"], ["tissuenet"], res_path, True)
    compile_results(["vit_h_deepbacs", "vit_b_deepbacs"], ["deepbacs"], res_path, True)

    # Add the cellpose results.
    add_cellpose_results(LM_DATASETS, res_path)


def main():
    # compile_em()
    compile_lm()


if __name__ == "__main__":
    main()
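
After compilation the CSV holds the model, dataset, and prompt columns inserted above, plus the metric columns from the individual evaluation files. A minimal sketch for inspecting the compiled table (the metric columns are assumed to be numeric, as produced by micro_sam's run_evaluation):

import pandas as pd

results = pd.read_csv("evaluation-lm.csv")  # file name as written by compile_lm
# Average the metric columns per model / dataset / prompt combination.
print(results.groupby(["model", "dataset", "prompt"]).mean(numeric_only=True))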
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
import os

from tqdm import tqdm
import imageio.v2 as imageio
import numpy as np

from torch_em.data import MinInstanceSampler
from torch_em.transform.label import label_consecutive
from torch_em.data.datasets import get_tissuenet_loader
from torch_em.transform.raw import standardize, normalize_percentile


def rgb_to_gray_transform(raw):
    # Normalize per channel, average the channels, then standardize.
    raw = normalize_percentile(raw, axis=(1, 2))
    raw = np.mean(raw, axis=0)
    raw = standardize(raw)
    return raw


def get_tissuenet_loaders(input_path):
    sampler = MinInstanceSampler()
    label_transform = label_consecutive
    raw_transform = rgb_to_gray_transform
    val_loader = get_tissuenet_loader(path=input_path, split="val", raw_channel="rgb", label_channel="cell",
                                      batch_size=1, patch_shape=(256, 256), num_workers=0,
                                      sampler=sampler, label_transform=label_transform, raw_transform=raw_transform)
    test_loader = get_tissuenet_loader(path=input_path, split="test", raw_channel="rgb", label_channel="cell",
                                       batch_size=1, patch_shape=(256, 256), num_workers=0,
                                       sampler=sampler, label_transform=label_transform, raw_transform=raw_transform)
    return val_loader, test_loader


def extract_images(loader, out_folder):
    os.makedirs(out_folder, exist_ok=True)
    for i, (x, y) in tqdm(enumerate(loader), total=len(loader)):
        img_path = os.path.join(out_folder, "image_{:04d}.tif".format(i))
        gt_path = os.path.join(out_folder, "label_{:04d}.tif".format(i))

        img = x.squeeze().detach().cpu().numpy()
        gt = y.squeeze().detach().cpu().numpy()

        imageio.imwrite(img_path, img)
        imageio.imwrite(gt_path, gt)


def main():
    val_loader, test_loader = get_tissuenet_loaders("/scratch-grete/projects/nim00007/data/tissuenet")
    print("Length of val loader is:", len(val_loader))
    print("Length of test loader is:", len(test_loader))

    root_save_dir = "/scratch/projects/nim00007/sam/datasets/tissuenet"

    # We use the val set for test, because there are some issues with the test set.
    # out_folder = os.path.join(root_save_dir, "test")
    # extract_images(val_loader, out_folder)

    # We use the test set for val and take as many images as we can sample.
    out_folder = os.path.join(root_save_dir, "val")
    extract_images(test_loader, out_folder)


if __name__ == "__main__":
    main()
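
The image_/label_ file naming matches the globs used by cellpose_baseline.py ("image*.tif" and "label*.tif"), so the extracted folders can be consumed by the baseline directly. A quick consistency check, with the output folder taken from the script above:

import os
from glob import glob

folder = "/scratch/projects/nim00007/sam/datasets/tissuenet/val"
images = sorted(glob(os.path.join(folder, "image*.tif")))
labels = sorted(glob(os.path.join(folder, "label*.tif")))
assert len(images) == len(labels), f"{len(images)} images vs. {len(labels)} labels"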

finetuning/generalists/data/precompute_prompts.py

Lines changed: 0 additions & 84 deletions
This file was deleted.
