
Commit d99f261

Merge pull request #123 from computational-cell-analytics/dev
Training and Evaluation Functionality
2 parents: 94ad363 + 74e6f3a

51 files changed (+4175, -60 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
```
@@ -3,3 +3,4 @@ __pycache__/
 *.pth
 *.tif
 examples/data/*
+*.out
```

environment_cpu.yaml

Lines changed: 1 addition & 0 deletions
```
@@ -11,6 +11,7 @@ dependencies:
 - pytorch
 - segment-anything
 - torchvision
+- torch_em >=0.5.1
 - tqdm
 # - pip:
 #   - git+https://github.com/facebookresearch/segment-anything.git
```

environment_gpu.yaml

Lines changed: 1 addition & 0 deletions
```
@@ -12,6 +12,7 @@ dependencies:
 - pytorch-cuda>=11.7 # you may need to update the cuda version to match your system
 - segment-anything
 - torchvision
+- torch_em >=0.5.1
 - tqdm
 # - pip:
 #   - git+https://github.com/facebookresearch/segment-anything.git
```
Lines changed: 23 additions & 0 deletions
```
import h5py
import micro_sam.sam_annotator as annotator
from micro_sam.util import get_sam_model

# TODO add an example for the 2d annotator with a custom model


def annotator_3d_with_custom_model():
    with h5py.File("./data/gut1_block_1.h5") as f:
        raw = f["raw"][:]

    custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt"
    embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr"
    predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h")
    annotator.annotator_3d(raw, embedding_path, predictor=predictor)


def main():
    annotator_3d_with_custom_model()


if __name__ == "__main__":
    main()
```
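Note that the checkpoint, input data, and embedding paths in this example are specific to the author's setup; they need to be adapted to your own fine-tuned checkpoint, input volume, and embedding cache before the example can be run.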

finetuning/.gitignore

Lines changed: 4 additions & 0 deletions
```
checkpoints/
logs/
sam_embeddings/
results/
```

finetuning/README.md

Lines changed: 49 additions & 0 deletions
# Segment Anything Finetuning

Preliminary examples for fine-tuning Segment Anything on custom datasets.

## LiveCELL

**Finetuning**

Run the script `livecell_finetuning.py` to fine-tune a model on LiveCELL.

**Inference**

The script `livecell_inference.py` can be used to run inference on the test set. It supports several arguments to run inference with different prompt configurations.
For example, run
```
python livecell_inference.py -c checkpoints/livecell_sam/best.pt -m vit_b -e experiment -i /scratch/projects/nim00007/data/LiveCELL --points --positive 1 --negative 0
```
for inference with one positive point prompt and no negative point prompts (the prompts are derived from the ground-truth).

The arguments `-c`, `-e` and `-i` specify where the model checkpoint is located, where the predictions from the model and other experiment data will be saved, and where the input dataset (LiveCELL) is stored.

To run the default set of experiments from our publication, use the command
```
python livecell_inference.py -c checkpoints/livecell_sam/best.pt -m vit_b -e experiment -i /scratch/projects/nim00007/data/LiveCELL -d --prompt_folder /scratch/projects/nim00007/sam/experiments/prompts/livecell
```

Here `-d` automatically runs the evaluation for these settings:
- `--points --positive 1 --negative 0` (using point prompts with a single positive point)
- `--points --positive 2 --negative 4` (using point prompts with two positive points and four negative points)
- `--points --positive 4 --negative 8` (using point prompts with four positive points and eight negative points)
- `--box` (using box prompts)

In addition, `--prompt_folder` specifies a folder with precomputed prompts. Using precomputed prompts significantly speeds up the experiments and makes them reproducible. (Without it the prompts will be recalculated each time.)

You can also evaluate the automatic instance segmentation functionality by running
```
python livecell_inference.py -c checkpoints/livecell_sam/best.pt -m vit_b -e experiment -i /scratch/projects/nim00007/data/LiveCELL -a
```

This will first perform a grid search for the best parameters on a subset of the validation set and then run inference on the test set. This can take up to a day.

**Evaluation**

The script `livecell_evaluation.py` can then be used to evaluate the results from the inference runs.
E.g., run the script as below to evaluate the previous predictions.
```
python livecell_evaluation.py -i /scratch/projects/nim00007/data/LiveCELL -e experiment
```
This will create a folder `experiment/results` with CSV tables containing the results per cell type and averaged over all images.
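The resulting tables can be inspected with pandas, e.g. to compare the different prompt settings. A minimal sketch, assuming one of the CSV files in `experiment/results` (the file name below is a placeholder, not part of this commit):

```
import pandas as pd

# Load one of the result tables written by livecell_evaluation.py.
# "instance_segmentation.csv" is a placeholder name; check experiment/results/
# for the tables that were actually written.
results = pd.read_csv("experiment/results/instance_segmentation.csv")

print(results)
# Average the numeric score columns over all rows (e.g. over cell types).
print(results.mean(numeric_only=True))
```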
Lines changed: 84 additions & 0 deletions
```
import argparse
import os
from glob import glob

import pickle
from subprocess import run

import micro_sam.evaluation as evaluation
from tqdm import tqdm

DATA_ROOT = "/scratch/projects/nim00007/sam/ood/LM"
PROMPT_ROOT = "/scratch/projects/nim00007/sam/experiments/prompts"


def get_paths(dataset):
    pattern = os.path.join(DATA_ROOT, dataset, "test", "labels_*.tif")
    paths = sorted(glob(pattern))
    assert len(paths) > 0, pattern
    return paths


def precompute_setting(prompt_settings, dataset):
    gt_paths = get_paths(dataset)
    prompt_folder = os.path.join(PROMPT_ROOT, dataset)
    evaluation.precompute_all_prompts(gt_paths, prompt_folder, prompt_settings)


def submit_array_job(prompt_settings, dataset):
    n_settings = len(prompt_settings)
    cmd = ["sbatch", "-a", f"0-{n_settings-1}", "precompute_prompts.sbatch", dataset]
    run(cmd)


def check_settings(dataset, settings, expected_len):
    prompt_folder = os.path.join(PROMPT_ROOT, dataset)

    def check_prompt_file(prompt_file):
        assert os.path.exists(prompt_file), prompt_file
        with open(prompt_file, "rb") as f:
            prompts = pickle.load(f)
        assert len(prompts) == expected_len, f"{len(prompts)}, {expected_len}"

    for setting in tqdm(settings, desc="Check prompt files"):
        pos, neg = setting["n_positives"], setting["n_negatives"]
        prompt_file = os.path.join(prompt_folder, f"points-p{pos}-n{neg}.pkl")
        if pos == 0 and neg == 0:
            prompt_file = os.path.join(prompt_folder, "boxes.pkl")
        check_prompt_file(prompt_file)

    print("All files checked!")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset")
    parser.add_argument("-c", "--check", action="store_true")
    args = parser.parse_args()

    # this will fail if the dataset is invalid
    gt_paths = get_paths(args.dataset)

    settings = evaluation.default_experiment_settings()
    # we may use this as the point setting instead of p2-n4,
    # so we also precompute it
    settings.append(
        {"use_points": True, "use_boxes": False, "n_positives": 4, "n_negatives": 8},  # p4-n8
    )

    if args.check:
        check_settings(args.dataset, settings, len(gt_paths))
        return

    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)

    if job_id is None:  # this is the main script that submits slurm jobs
        submit_array_job(settings, args.dataset)
    else:  # we're in a slurm job and precompute a setting
        job_id = int(job_id)
        this_settings = [settings[job_id]]
        precompute_setting(this_settings, args.dataset)


if __name__ == "__main__":
    main()
```
Lines changed: 9 additions & 0 deletions
```
#! /bin/bash
#SBATCH -c 4
#SBATCH --mem 48G
#SBATCH -t 720
#SBATCH -p grete:shared
#SBATCH -G A100:1

source activate sam
python precompute_prompts.py $@
```
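Together, these two files implement a self-submitting SLURM array job: when the Python script is run outside of SLURM (no `SLURM_ARRAY_TASK_ID` in the environment), e.g. as `python precompute_prompts.py deepbacs` (the dataset name here is only an example), it submits one array task per prompt setting via `sbatch -a`. Each task then re-enters the same script through the sbatch wrapper and precomputes only the setting selected by its array index. Running the script again with `-c`/`--check` verifies that every expected prompt file exists and contains one entry per label image.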
Lines changed: 79 additions & 0 deletions
```
import os
from glob import glob

import imageio.v3 as imageio
import numpy as np

from sklearn.model_selection import train_test_split

ROOT = "/scratch-grete/projects/nim00007/data/deepbacs"


def download_deepbacs():
    from torch_em.data.datasets import get_deepbacs_loader
    get_deepbacs_loader(ROOT, "train", bac_type="mixed", download=True, patch_shape=(256, 256), batch_size=1)
    get_deepbacs_loader(ROOT, "test", bac_type="mixed", download=True, patch_shape=(256, 256), batch_size=1)


# old code from Anwai
def get_deepbacs_test_images():
    root = ROOT
    output_root = "/scratch-grete/projects/nim00007/sam/ood/LM/deepbacs"

    def write_split(images, labels, split):
        out_folder = os.path.join(output_root, split)
        os.makedirs(out_folder, exist_ok=True)
        for ii, (im, lab) in enumerate(zip(images, labels)):
            out_im = os.path.join(out_folder, f"image_{ii:04}.tif")
            out_lab = os.path.join(out_folder, f"labels_{ii:04}.tif")
            im, lab = imageio.imread(im), imageio.imread(lab)
            imageio.imwrite(out_im, im)
            imageio.imwrite(out_lab, lab)

    root_imgs = glob(os.path.join(root, "mixed", "test", "source", "*"))
    root_gts = glob(os.path.join(root, "mixed", "test", "target", "*"))
    np.random.seed(0)

    val_images = np.random.choice(root_imgs, size=5, replace=False).tolist()
    val_labels = [gt_p for gt_p in root_gts if os.path.basename(gt_p) in [os.path.basename(x) for x in val_images]]

    test_images = [ip for ip in root_imgs if ip not in val_images]
    test_labels = [gp for gp in root_gts if gp not in val_labels]

    write_split(val_images, val_labels, "val")
    write_split(test_images, test_labels, "test")


# new simplified code
def get_deepbacs_test_images_new():
    root = ROOT
    output_root = "/scratch-grete/projects/nim00007/sam/ood/LM/deepbacs"

    def write_split(images, labels, split):
        out_folder = os.path.join(output_root, split)
        os.makedirs(out_folder, exist_ok=True)
        for ii, (im, lab) in enumerate(zip(images, labels)):
            out_im = os.path.join(out_folder, f"image_{ii:04}.tif")
            out_lab = os.path.join(out_folder, f"labels_{ii:04}.tif")
            im, lab = imageio.imread(im), imageio.imread(lab)
            imageio.imwrite(out_im, im)
            imageio.imwrite(out_lab, lab)

    images = sorted(glob(os.path.join(root, "mixed", "test", "source", "*")))
    labels = sorted(glob(os.path.join(root, "mixed", "test", "target", "*")))

    test_images, val_images, test_labels, val_labels = train_test_split(
        images, labels, test_size=0.15, random_state=42
    )

    write_split(val_images, val_labels, "val")
    write_split(test_images, test_labels, "test")


def main():
    # download_deepbacs()
    get_deepbacs_test_images_new()


if __name__ == "__main__":
    main()
```
Lines changed: 39 additions & 0 deletions
```
import os
from glob import glob

import imageio.v3 as imageio

ROOT = "/scratch-grete/projects/nim00007/data/tissuenet"


def get_tissuenet_images(split):
    assert split in ["val", "test"]
    val_set, test_set = glob(os.path.join(ROOT, "val", "*")), glob(os.path.join(ROOT, "test", "*"))
    if split == "val":
        return sorted(val_set)
    else:
        return sorted(test_set)


# TODO: still unfinished - the image and label paths (val_images, val_labels,
# test_images, test_labels) still need to be derived from the split folders
# before the splits can be written below.
def create_tissuenet_splits():
    output_root = "/scratch-grete/projects/nim00007/sam/ood/LM/tissuenet"

    def write_split(images, labels, split):
        out_folder = os.path.join(output_root, split)
        os.makedirs(out_folder, exist_ok=True)
        for ii, (im, lab) in enumerate(zip(images, labels)):
            out_im = os.path.join(out_folder, f"image_{ii:04}.tif")
            out_lab = os.path.join(out_folder, f"labels_{ii:04}.tif")
            im, lab = imageio.imread(im), imageio.imread(lab)
            imageio.imwrite(out_im, im)
            imageio.imwrite(out_lab, lab)

    val_set = get_tissuenet_images("val")

    write_split(val_images, val_labels, "val")
    write_split(test_images, test_labels, "test")


if __name__ == "__main__":
    create_tissuenet_splits()
```
