Commit 19ae2c8

Update generalist experiments

1 parent d99f261 commit 19ae2c8

File tree

7 files changed: +216 −212 lines changed

Lines changed: 63 additions & 0 deletions (new file)

@@ -0,0 +1,63 @@

import os
from tqdm import tqdm
import imageio.v2 as imageio
import numpy as np

from torch_em.data import MinInstanceSampler
from torch_em.transform.label import label_consecutive
from torch_em.data.datasets import get_tissuenet_loader
from torch_em.transform.raw import standardize, normalize_percentile


def rgb_to_gray_transform(raw):
    raw = normalize_percentile(raw, axis=(1, 2))
    raw = np.mean(raw, axis=0)
    raw = standardize(raw)
    return raw


def get_tissuenet_loaders(input_path):
    sampler = MinInstanceSampler()
    label_transform = label_consecutive
    raw_transform = rgb_to_gray_transform
    val_loader = get_tissuenet_loader(path=input_path, split="val", raw_channel="rgb", label_channel="cell",
                                      batch_size=1, patch_shape=(256, 256), num_workers=0,
                                      sampler=sampler, label_transform=label_transform, raw_transform=raw_transform)
    test_loader = get_tissuenet_loader(path=input_path, split="test", raw_channel="rgb", label_channel="cell",
                                       batch_size=1, patch_shape=(256, 256), num_workers=0,
                                       sampler=sampler, label_transform=label_transform, raw_transform=raw_transform)
    return val_loader, test_loader


def extract_images(loader, out_folder):
    os.makedirs(out_folder, exist_ok=True)
    for i, (x, y) in tqdm(enumerate(loader), total=len(loader)):
        img_path = os.path.join(out_folder, "image_{:04d}.tif".format(i))
        gt_path = os.path.join(out_folder, "label_{:04d}.tif".format(i))

        img = x.squeeze().detach().cpu().numpy()
        gt = y.squeeze().detach().cpu().numpy()

        imageio.imwrite(img_path, img)
        imageio.imwrite(gt_path, gt)


def main():
    val_loader, test_loader = get_tissuenet_loaders("/scratch-grete/projects/nim00007/data/tissuenet")
    print("Length of val loader is:", len(val_loader))
    print("Length of test loader is:", len(test_loader))

    root_save_dir = "/scratch/projects/nim00007/sam/datasets/tissuenet"

    # we use the val set for test because there are some issues with the test set
    # out_folder = os.path.join(root_save_dir, "test")
    # extract_images(val_loader, out_folder)

    # we use the test folder for val and just use as many images as we can sample
    out_folder = os.path.join(root_save_dir, "val")
    extract_images(test_loader, out_folder)


if __name__ == "__main__":
    main()
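For orientation: rgb_to_gray_transform takes a channels-first RGB patch, applies per-channel percentile normalization, averages the channels down to grayscale, and standardizes the result. Below is a minimal numpy-only sketch of the same steps, assuming a (C, H, W) float input; _percentile_norm and its percentile bounds are illustrative stand-ins for torch_em's normalize_percentile, not part of the commit.

# Numpy-only illustration of the transform above; _percentile_norm stands in
# for torch_em's normalize_percentile, and the bounds are assumed defaults.
import numpy as np


def _percentile_norm(x, lower=1.0, upper=99.0, eps=1e-7):
    # normalize each channel to its own [lower, upper] percentile range
    lo = np.percentile(x, lower, axis=(1, 2), keepdims=True)
    hi = np.percentile(x, upper, axis=(1, 2), keepdims=True)
    return (x - lo) / (hi - lo + eps)


rgb = np.random.rand(3, 256, 256).astype("float32")  # dummy (C, H, W) patch
gray = _percentile_norm(rgb).mean(axis=0)            # collapse RGB to one channel
gray = (gray - gray.mean()) / (gray.std() + 1e-7)    # standardize: zero mean, unit variance
print(gray.shape)  # (256, 256)

Applied to a (3, 256, 256) patch this yields a single (256, 256) zero-mean image, matching the single-channel data the loaders above write out.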

finetuning/generalists/data/precompute_prompts.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

finetuning/generalists/data/prepare_deepbacs_data.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

finetuning/generalists/data/prepare_tissuenet_data.py

Lines changed: 0 additions & 39 deletions
This file was deleted.
Lines changed: 54 additions & 0 deletions (new file)

@@ -0,0 +1,54 @@

import argparse
import os

from subprocess import run

import micro_sam.evaluation as evaluation
from util import get_data_paths, ALL_DATASETS

PROMPT_ROOT = "/scratch/projects/nim00007/sam/experiments/prompts"


def precompute_prompts(dataset):
    # everything for livecell has been computed already
    if dataset == "livecell":
        return

    prompt_folder = os.path.join(PROMPT_ROOT, dataset)
    _, gt_paths = get_data_paths(dataset, "test")

    settings = evaluation.default_experiment_settings()
    evaluation.precompute_all_prompts(gt_paths, prompt_folder, settings)


def precompute_prompts_slurm(job_id):
    dataset = ALL_DATASETS[job_id]
    precompute_prompts(dataset)


def submit_array_job():
    n_datasets = len(ALL_DATASETS)
    cmd = ["sbatch", "-a", f"0-{n_datasets-1}", "precompute_prompts.sbatch"]
    run(cmd)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset")
    args = parser.parse_args()
    if args.dataset is not None:
        precompute_prompts(args.dataset)
        return

    # this will fail if the dataset is invalid
    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)

    if job_id is None:  # this is the main script that submits slurm jobs
        submit_array_job()
    else:  # we're in a slurm job and precompute a setting
        job_id = int(job_id)
        precompute_prompts_slurm(job_id)


if __name__ == "__main__":
    main()
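Usage follows from main: python precompute_prompts.py -d <dataset> computes prompts for a single dataset in the current process, while running it with no arguments (and no SLURM_ARRAY_TASK_ID in the environment) submits one array task per entry in ALL_DATASETS; each task then re-enters the same script as a worker. Below is a stripped-down sketch of this self-submitting array pattern, with a hypothetical my_job.sbatch and task list standing in for the real ones.

# Generic sketch of the self-submitting SLURM array pattern used above.
# "my_job.sbatch" and TASKS are hypothetical placeholders.
import os
from subprocess import run

TASKS = ["dataset_a", "dataset_b", "dataset_c"]


def dispatch():
    task_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
    if task_id is None:
        # login node: submit one array task per entry in TASKS
        run(["sbatch", "-a", f"0-{len(TASKS) - 1}", "my_job.sbatch"])
    else:
        # inside an array task: handle exactly the task for this index
        print("processing", TASKS[int(task_id)])


if __name__ == "__main__":
    dispatch()

Keying the branch off SLURM_ARRAY_TASK_ID lets one file act as both submitter and worker, so the sbatch script only has to call back into the same entry point.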

finetuning/generalists/data/precompute_prompts.sbatch renamed to finetuning/generalists/precompute_prompts.sbatch

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@
 #SBATCH -t 720
 #SBATCH -p grete:shared
 #SBATCH -G A100:1
+#SBATCH -A nim00007

 source activate sam
 python precompute_prompts.py $@
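The one-line change adds #SBATCH -A nim00007, which books the array jobs against the nim00007 project account (the same project that owns the data paths above); without an -A directive, sbatch falls back to the submitting user's default account.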
